From 0fc15089039c758f9cfacb7c7cb824102c7b59c1 Mon Sep 17 00:00:00 2001
From: Haoqin Tu
Date: Thu, 23 May 2024 09:59:05 +0800
Subject: [PATCH] Add Automatic GPT4V Evaluation for VLM Originality Evaluation
 (#2576)

---
 .../gpt4v_originality_critique_metrics.py     | 126 ++++++++++++++++++
 src/helm/benchmark/run_specs/vlm_run_specs.py |  24 +++-
 .../vision_language/mementos_scenario.py      |   4 +-
 src/helm/clients/openai_client.py             |   1 -
 src/helm/common/critique_request.py           |   8 +-
 .../proxy/critique/model_critique_client.py   |  11 ++
 6 files changed, 168 insertions(+), 6 deletions(-)
 create mode 100644 src/helm/benchmark/metrics/gpt4v_originality_critique_metrics.py

diff --git a/src/helm/benchmark/metrics/gpt4v_originality_critique_metrics.py b/src/helm/benchmark/metrics/gpt4v_originality_critique_metrics.py
new file mode 100644
index 00000000000..5c45d1b583f
--- /dev/null
+++ b/src/helm/benchmark/metrics/gpt4v_originality_critique_metrics.py
@@ -0,0 +1,126 @@
+from typing import Dict, List
+
+from helm.benchmark.adaptation.request_state import RequestState
+from helm.benchmark.adaptation.scenario_state import ScenarioState
+from helm.benchmark.adaptation.adapter_spec import AdapterSpec
+from helm.benchmark.metrics.metric import MetricInterface, MetricResult, PerInstanceStats, add_context
+from helm.benchmark.metrics.metric_name import MetricContext, MetricName
+from helm.benchmark.metrics.metric_service import MetricService
+from helm.benchmark.metrics.statistic import Stat, merge_stat
+from helm.common.critique_request import CritiqueTaskTemplate, CritiqueQuestionTemplate, CritiqueRequest, QuestionType
+from helm.common.hierarchical_logger import hlog
+from helm.common.request import RequestResult, Request, GeneratedOutput
+from helm.common.media_object import MultimediaObject, IMAGE_TYPE, MediaObject
+
+
+class GPT4VCritiqueMetric(MetricInterface):
+    """
+    Critique evaluation that uses GPT-4V to judge how original the generated text is, given the input image.
+    """
+
+    # We can add more evaluation aspects here
+    ORIGINALITY_NAME: str = "originality_gpt4v"
+    ORIGINALITY_ANSWER_TO_SCORE: Dict[str, int] = {
+        "I’ve seen something like this before to the point it’s become tiresome.": 1,
+        "The text is not really original, but it has some originality to it.": 2,
+        "Neutral.": 3,
+        "I find the text to be fresh and original.": 4,
+        "I find the text to be extremely creative and out of this world.": 5,
+    }
+
+    def __init__(self, num_respondents: int):
+        self._num_respondents = num_respondents
+
+    def __repr__(self) -> str:
+        return "GPT4VCritiqueMetric()"
+
+    def evaluate(
+        self,
+        scenario_state: ScenarioState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+        parallelism: int,
+    ) -> MetricResult:
+        request_states: List[RequestState] = scenario_state.request_states
+
+        all_stats: Dict[MetricName, Stat] = {}
+        per_instance_stats: List[PerInstanceStats] = []
+        for request_state in request_states:
+            context = MetricContext.from_instance(request_state.instance)
+            stats_without_context = self.evaluate_generation(
+                scenario_state.adapter_spec,
+                request_state,
+                metric_service,
+                eval_cache_path,
+            )
+            stats = [add_context(stat_without_context, context) for stat_without_context in stats_without_context]
+            for stat in stats:
+                merge_stat(all_stats, stat)
+            assert request_state.instance.id is not None
+            per_instance_stats.append(
+                PerInstanceStats(
+                    instance_id=request_state.instance.id,
+                    perturbation=request_state.instance.perturbation,
+                    train_trial_index=request_state.train_trial_index,
+                    stats=stats,
+                )
+            )
+        return MetricResult(aggregated_stats=list(all_stats.values()), per_instance_stats=per_instance_stats)
+
+    def evaluate_generation(
+        self,
+        adapter_spec: AdapterSpec,
+        request_state: RequestState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+    ) -> List[Stat]:
+        input_request: Request = request_state.request
+        # Predicted outputs and their originality scores
+        assert request_state.result is not None
+        request_result: RequestResult = request_state.result
+        # Get input image and generated response for the originality evaluation
+        assert input_request.multimodal_prompt is not None
+        completions: List[GeneratedOutput] = request_result.completions
+        input_text: str = completions[0].text
+        input_media: MultimediaObject = input_request.multimodal_prompt
+        image_objects: List[MediaObject] = [
+            item for item in input_media.media_objects if item.is_type(IMAGE_TYPE) and item.location
+        ]
+
+        template = CritiqueTaskTemplate(
+            name="vhelm_gpt4v_originality",
+            # TODO: Add proper instructions
+            instructions="Answer the question given the text and image, remember to only answer "
+            "with a capital letter.\n\n{{prompt}}",
+            num_respondents=self._num_respondents,
+            questions=[
+                CritiqueQuestionTemplate(
+                    name=self.ORIGINALITY_NAME,
+                    question_type=QuestionType.MULTIPLE_CHOICE,
+                    text="How original is the text, given it was created with the image?",
+                    options=list(self.ORIGINALITY_ANSWER_TO_SCORE.keys()),
+                    media_object=image_objects[0],  # We only take the first image as input
+                )
+            ],
+        )
+        request = CritiqueRequest(template=template, fields={"prompt": input_text})
+
+        # Send the critique request
+        result = metric_service.make_critique_request(request)
+        if not result or not result.responses:
+            # Skip computing metrics if there aren't any responses yet
+            hlog("Waiting for responses to be generated.")
+            return []
+
+        stats: Dict[str, Stat] = {}
+        for question in template.questions:
+            stats[question.name] = Stat(MetricName(question.name))
+
+        for response in result.responses:
+            for answer_name, answer in response.answers.items():
+                assert isinstance(answer, str)
+                answer_value: float
+                answer_value = self.ORIGINALITY_ANSWER_TO_SCORE[answer]
+                stats[answer_name].add(answer_value)
+
+        return list(stats.values())
diff --git a/src/helm/benchmark/run_specs/vlm_run_specs.py b/src/helm/benchmark/run_specs/vlm_run_specs.py
index 7707d96b343..b36dcf77bcd 100644
--- a/src/helm/benchmark/run_specs/vlm_run_specs.py
+++ b/src/helm/benchmark/run_specs/vlm_run_specs.py
@@ -70,6 +70,13 @@ def _get_captioning_adapter_spec() -> AdapterSpec:
     )
 
 
+def get_open_end_answer_generation_adapter_spec() -> AdapterSpec:
+    return _get_generation_adapter_spec(
+        instructions="Follow the given instruction and give your complete answer.",
+        max_tokens=100,
+    )
+
+
 def _get_multiple_choice_joint_adapter_spec(
     input_noun: Optional[str],
     output_noun: str,
@@ -139,6 +146,17 @@ def _get_image2structure_metric_specs(
     return metric_specs + get_basic_reference_metric_specs()
 
 
+def get_gpt4v_critique_originality_metric_specs(num_respondents: int) -> List[MetricSpec]:
+    return [
+        MetricSpec(
+            class_name="helm.benchmark.metrics.gpt4v_originality_critique_metrics.GPT4VCritiqueMetric",
+            args={
+                "num_respondents": num_respondents,
+            },
+        )
+    ]
+
+
 ############################################################
 # VHELM run specs
 
@@ -739,13 +757,13 @@ def get_pairs_spec(subset: str, person: str) -> RunSpec:
 
 
 @run_spec_function("mementos")
-def get_mementos_spec(subject: str) -> RunSpec:
+def get_mementos_spec(subject: str, num_respondents: int) -> RunSpec:
     scenario_spec = ScenarioSpec(
         class_name="helm.benchmark.scenarios.vision_language.mementos_scenario.MementosScenario",
         args={"subject": subject},
     )
-    adapter_spec: AdapterSpec = _get_short_answer_generation_adapter_spec()
-    metric_specs: List[MetricSpec] = _get_open_ended_generation_metric_specs()
+    adapter_spec: AdapterSpec = get_open_end_answer_generation_adapter_spec()
+    metric_specs: List[MetricSpec] = get_gpt4v_critique_originality_metric_specs(num_respondents=num_respondents)
 
     run_spec_name: str = "mementos"
     return RunSpec(
diff --git a/src/helm/benchmark/scenarios/vision_language/mementos_scenario.py b/src/helm/benchmark/scenarios/vision_language/mementos_scenario.py
index f659a88a495..d077fde0c7b 100644
--- a/src/helm/benchmark/scenarios/vision_language/mementos_scenario.py
+++ b/src/helm/benchmark/scenarios/vision_language/mementos_scenario.py
@@ -51,6 +51,8 @@ class MementosScenario(Scenario):
         "Write a description for the given image sequence in a single paragraph, what is happening in this episode?"
     )
 
+    ORIGINALITY_QUESTION_PROMPT: str = "Write a creative and original story for the given image sequence."
+
     SUBJECTS: List[str] = ["comics", "dailylife", "robotics"]
 
     name = "mementos"
@@ -98,7 +100,7 @@ def get_instances(self, output_path: str) -> List[Instance]:
 
                 content: List[MediaObject] = [
                     MediaObject(location=local_image_path, content_type="image/png"),
-                    MediaObject(text=self.QUESTION_PROMPT, content_type="text/plain"),
+                    MediaObject(text=self.ORIGINALITY_QUESTION_PROMPT, content_type="text/plain"),
                 ]
                 answer: str = row["description"]
                 instances.append(
diff --git a/src/helm/clients/openai_client.py b/src/helm/clients/openai_client.py
index faa3dec1e2d..1b1162fbe98 100644
--- a/src/helm/clients/openai_client.py
+++ b/src/helm/clients/openai_client.py
@@ -63,7 +63,6 @@ def _get_cache_key(self, raw_request: Dict, request: Request):
         if request.multimodal_prompt:
             prompt_key: str = generate_uid_for_multimodal_prompt(request.multimodal_prompt)
             cache_key = {**cache_key, "multimodal_prompt": prompt_key}
-            assert not cache_key["messages"]
             del cache_key["messages"]
         return cache_key
 
diff --git a/src/helm/common/critique_request.py b/src/helm/common/critique_request.py
index bc708ff5eef..677718149ba 100644
--- a/src/helm/common/critique_request.py
+++ b/src/helm/common/critique_request.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
-from typing import Dict, List, Union
+from typing import Dict, List, Union, Optional
+from helm.common.media_object import MediaObject
 
 
 class QuestionType:
@@ -34,6 +35,11 @@ class CritiqueQuestionTemplate:
     Can contain placeholders like {{placeholder}} that will be interpolated
     using the fields in CritiqueRequest."""
 
+    media_object: Optional[MediaObject] = None
+    """Optional image shown to the critique model alongside the question, for multimodal input.
+
+    Specified as a local image path or a URL."""
+
 
 @dataclass(frozen=True)
 class CritiqueTaskTemplate:
diff --git a/src/helm/proxy/critique/model_critique_client.py b/src/helm/proxy/critique/model_critique_client.py
index 0f843381654..809672037a6 100644
--- a/src/helm/proxy/critique/model_critique_client.py
+++ b/src/helm/proxy/critique/model_critique_client.py
@@ -15,6 +15,7 @@
 from helm.common.request import Request, RequestResult, GeneratedOutput
 from helm.clients.client import Client
 from helm.proxy.critique.critique_client import CritiqueClient
+from helm.common.media_object import MultimediaObject, MediaObject
 
 
 class CritiqueParseError(Exception):
@@ -31,6 +32,7 @@ def __init__(self, client: Client, model_name):
             get_default_model_deployment_for_model(model_name, warn_arg_deprecated=False, ignore_deprecated=True)
             or self._model_name
         )
+        self.vision_language: bool = model_name.startswith("openai/gpt-4-vision")
 
     def _interpolate_fields(self, text: str, fields: Dict[str, str]) -> str:
         for key, value in fields.items():
@@ -78,12 +80,21 @@ def _task_to_requests(self, task: CritiqueTaskTemplate, fields: Dict[str, str])
             prompt = anthropic.HUMAN_PROMPT + prompt + anthropic.AI_PROMPT
 
+            multimodal_prompt: Optional[MultimediaObject] = None
+            if self.vision_language:
+                assert question.media_object is not None, "Expected a media_object for vision-language models"
+                image_media: MediaObject = question.media_object
+                text_media: MediaObject = MediaObject(text=prompt, content_type="text/plain")
+                multimodal_prompt = MultimediaObject(media_objects=[image_media, text_media])
+                prompt = ""  # Set to an empty string to avoid conflicts with multimodal_prompt
+
             request = Request(
                 model=self._model_name,
                 model_deployment=self._model_deployment_name,
                 prompt=prompt,
                 max_tokens=max_tokens,
                 echo_prompt=False,
+                multimodal_prompt=multimodal_prompt,
             )
             requests.append(request)
         return requests
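
Usage sketch (not part of the patch): the snippet below illustrates how the pieces introduced here compose. A critique question carrying the new media_object field is wrapped in a CritiqueRequest, mirroring what GPT4VCritiqueMetric builds internally before handing it to the metric service; ModelCritiqueClient then packs the image and the interpolated prompt into a multimodal Request for a GPT-4V deployment. The image path and the prompt text are hypothetical placeholders, and the constructor arguments shown are only those exercised by this patch.

    from helm.common.critique_request import (
        CritiqueQuestionTemplate,
        CritiqueRequest,
        CritiqueTaskTemplate,
        QuestionType,
    )
    from helm.common.media_object import MediaObject

    # Hypothetical local image; in the metric, this comes from the instance's multimodal prompt.
    image = MediaObject(location="episode_0001.png", content_type="image/png")

    template = CritiqueTaskTemplate(
        name="vhelm_gpt4v_originality",
        instructions="Answer the question given the text and image, remember to only answer "
        "with a capital letter.\n\n{{prompt}}",
        num_respondents=1,
        questions=[
            CritiqueQuestionTemplate(
                name="originality_gpt4v",
                question_type=QuestionType.MULTIPLE_CHOICE,
                text="How original is the text, given it was created with the image?",
                options=[
                    "Neutral.",
                    "I find the text to be fresh and original.",
                ],
                media_object=image,  # New optional field added by this patch
            )
        ],
    )

    # The {{prompt}} placeholder is interpolated with the model's generated text.
    request = CritiqueRequest(template=template, fields={"prompt": "A model-generated story..."})

With these changes, the mementos run spec takes an extra num_respondents argument, so a run entry would look roughly like mementos:subject=dailylife,num_respondents=1; the exact runner invocation and flags depend on the HELM version in use.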