Add Automatic GPT4V Evaluation for VLM Originality Evaluation (stanfo…
ImKeTT authored and xuwangyin committed Jun 23, 2024
1 parent 7bc5e2f commit 0fc1508
Showing 6 changed files with 168 additions and 6 deletions.
126 changes: 126 additions & 0 deletions src/helm/benchmark/metrics/gpt4v_originality_critique_metrics.py
@@ -0,0 +1,126 @@
from typing import Dict, List

from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.adaptation.scenario_state import ScenarioState
from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.metrics.metric import MetricInterface, MetricResult, PerInstanceStats, add_context
from helm.benchmark.metrics.metric_name import MetricContext, MetricName
from helm.benchmark.metrics.metric_service import MetricService
from helm.benchmark.metrics.statistic import Stat, merge_stat
from helm.common.critique_request import CritiqueTaskTemplate, CritiqueQuestionTemplate, CritiqueRequest, QuestionType
from helm.common.hierarchical_logger import hlog
from helm.common.request import RequestResult, Request, GeneratedOutput
from helm.common.media_object import MultimediaObject, IMAGE_TYPE, MediaObject


class GPT4VCritiqueMetric(MetricInterface):
"""
    Critique evaluation that uses GPT-4V to judge how original the generated text is, given the input image.
"""

# We can add more evaluation aspects here
ORIGINALITY_NAME: str = "originality_gpt4v"
ORIGINALITY_ANSWER_TO_SCORE: Dict[str, int] = {
"I’ve seen something like this before to the point it’s become tiresome.": 1,
"The text is not really original, but it has some originality to it.": 2,
"Neutral.": 3,
"I find the text to be fresh and original.": 4,
"I find the text to be extremely creative and out of this world.": 5,
}

def __init__(self, num_respondents: int):
self._num_respondents = num_respondents

def __repr__(self) -> str:
return "GPT4CritiqueMetric()"

def evaluate(
self,
scenario_state: ScenarioState,
metric_service: MetricService,
eval_cache_path: str,
parallelism: int,
) -> MetricResult:
request_states: List[RequestState] = scenario_state.request_states

all_stats: Dict[MetricName, Stat] = {}
per_instance_stats: List[PerInstanceStats] = []
for request_state in request_states:
context = MetricContext.from_instance(request_state.instance)
stats_without_context = self.evaluate_generation(
scenario_state.adapter_spec,
request_state,
metric_service,
eval_cache_path,
)
stats = [add_context(stat_without_context, context) for stat_without_context in stats_without_context]
for stat in stats:
merge_stat(all_stats, stat)
assert request_state.instance.id is not None
per_instance_stats.append(
PerInstanceStats(
instance_id=request_state.instance.id,
perturbation=request_state.instance.perturbation,
train_trial_index=request_state.train_trial_index,
stats=stats,
)
)
return MetricResult(aggregated_stats=list(all_stats.values()), per_instance_stats=per_instance_stats)

def evaluate_generation(
self,
adapter_spec: AdapterSpec,
request_state: RequestState,
metric_service: MetricService,
eval_cache_path: str,
) -> List[Stat]:
input_request: Request = request_state.request
# Predicted outputs and their originality scores
assert request_state.result is not None
request_result: RequestResult = request_state.result
# Get input image and generated response for the originality evaluation
assert input_request.multimodal_prompt is not None
completions: List[GeneratedOutput] = request_result.completions
input_text: str = completions[0].text
input_media: MultimediaObject = input_request.multimodal_prompt
image_objects: List[MediaObject] = [
item for item in input_media.media_objects if item.is_type(IMAGE_TYPE) and item.location
]

template = CritiqueTaskTemplate(
name="vhelm_gpt4v_originality",
# TODO: Add proper instructions
instructions="Answer the question given the text and image, remember to only answer "
"with a capital letter.\n\n{{prompt}}",
num_respondents=self._num_respondents,
questions=[
CritiqueQuestionTemplate(
name=self.ORIGINALITY_NAME,
question_type=QuestionType.MULTIPLE_CHOICE,
text="How original is the text, given it was created with the image?",
options=list(self.ORIGINALITY_ANSWER_TO_SCORE.keys()),
media_object=image_objects[0], # we only take the first image as input
)
],
)
request = CritiqueRequest(template=template, fields={"prompt": input_text})

        # Send the critique request to the metric service
result = metric_service.make_critique_request(request)
if not result or not result.responses:
# Skip computing metrics if there aren't any responses yet
hlog("Waiting for responses to be generated.")
return []

stats: Dict[str, Stat] = {}
for question in template.questions:
stats[question.name] = Stat(MetricName(question.name))

for response in result.responses:
for answer_name, answer in response.answers.items():
assert isinstance(answer, str)
answer_value: float
answer_value = self.ORIGINALITY_ANSWER_TO_SCORE[answer]
stats[answer_name].add(answer_value)

return list(stats.values())
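
As a quick illustration of the scoring loop above, here is a minimal sketch (assuming HELM is installed with this change) that maps a single hypothetical GPT-4V answer onto its numeric originality score and records it in a Stat, mirroring the metric's own bookkeeping:

from helm.benchmark.metrics.gpt4v_originality_critique_metrics import GPT4VCritiqueMetric
from helm.benchmark.metrics.metric_name import MetricName
from helm.benchmark.metrics.statistic import Stat

# Hypothetical critique answer; it must be one of the multiple-choice options defined above.
metric = GPT4VCritiqueMetric(num_respondents=1)
stat = Stat(MetricName(metric.ORIGINALITY_NAME))
answer = "I find the text to be fresh and original."
stat.add(metric.ORIGINALITY_ANSWER_TO_SCORE[answer])  # records a score of 4
print(stat)
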
24 changes: 21 additions & 3 deletions src/helm/benchmark/run_specs/vlm_run_specs.py
@@ -70,6 +70,13 @@ def _get_captioning_adapter_spec() -> AdapterSpec:
)


def get_open_end_answer_generation_adapter_spec() -> AdapterSpec:
return _get_generation_adapter_spec(
instructions="Follow the given instruction and give your complete answer.",
max_tokens=100,
)


def _get_multiple_choice_joint_adapter_spec(
input_noun: Optional[str],
output_noun: str,
@@ -139,6 +146,17 @@ def _get_image2structure_metric_specs(
return metric_specs + get_basic_reference_metric_specs()


def get_gpt4v_critique_originality_metric_specs(num_respondents: int) -> List[MetricSpec]:
return [
MetricSpec(
class_name="helm.benchmark.metrics.gpt4v_originality_critique_metrics.GPT4VCritiqueMetric",
args={
"num_respondents": num_respondents,
},
)
]


############################################################
# VHELM run specs

@@ -739,13 +757,13 @@ def get_pairs_spec(subset: str, person: str) -> RunSpec:


@run_spec_function("mementos")
def get_mementos_spec(subject: str) -> RunSpec:
def get_mementos_spec(subject: str, num_respondents: int) -> RunSpec:
scenario_spec = ScenarioSpec(
class_name="helm.benchmark.scenarios.vision_language.mementos_scenario.MementosScenario",
args={"subject": subject},
)
adapter_spec: AdapterSpec = _get_short_answer_generation_adapter_spec()
metric_specs: List[MetricSpec] = _get_open_ended_generation_metric_specs()
adapter_spec: AdapterSpec = get_open_end_answer_generation_adapter_spec()
metric_specs: List[MetricSpec] = get_gpt4v_critique_originality_metric_specs(num_respondents=num_respondents)

run_spec_name: str = "mementos"
return RunSpec(
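
For orientation, a hedged sketch of how the updated run spec function could be exercised directly; the exact RunSpec name format comes from code not shown in this diff:

from helm.benchmark.run_specs.vlm_run_specs import get_mementos_spec

# num_respondents is forwarded to the GPT-4V originality critique metric spec.
run_spec = get_mementos_spec(subject="comics", num_respondents=1)
for metric_spec in run_spec.metric_specs:
    print(metric_spec.class_name)  # expect ...gpt4v_originality_critique_metrics.GPT4VCritiqueMetric
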
4 changes: 3 additions & 1 deletion src/helm/benchmark/scenarios/vision_language/mementos_scenario.py
@@ -51,6 +51,8 @@ class MementosScenario(Scenario):
"Write a description for the given image sequence in a single paragraph, what is happening in this episode?"
)

ORIGINALITY_QUESTION_PROMPT: str = "Write a creative and original story for the given image sequence."

SUBJECTS: List[str] = ["comics", "dailylife", "robotics"]

name = "mementos"
@@ -98,7 +100,7 @@ def get_instances(self, output_path: str) -> List[Instance]:

content: List[MediaObject] = [
MediaObject(location=local_image_path, content_type="image/png"),
MediaObject(text=self.QUESTION_PROMPT, content_type="text/plain"),
MediaObject(text=self.ORIGINALITY_QUESTION_PROMPT, content_type="text/plain"),
]
answer: str = row["description"]
instances.append(
1 change: 0 additions & 1 deletion src/helm/clients/openai_client.py
@@ -63,7 +63,6 @@ def _get_cache_key(self, raw_request: Dict, request: Request):
if request.multimodal_prompt:
prompt_key: str = generate_uid_for_multimodal_prompt(request.multimodal_prompt)
cache_key = {**cache_key, "multimodal_prompt": prompt_key}
assert not cache_key["messages"]
del cache_key["messages"]
return cache_key

8 changes: 7 additions & 1 deletion src/helm/common/critique_request.py
@@ -1,5 +1,6 @@
from dataclasses import dataclass
from typing import Dict, List, Union
from typing import Dict, List, Union, Optional
from helm.common.media_object import MediaObject


class QuestionType:
@@ -34,6 +35,11 @@ class CritiqueQuestionTemplate:
Can contain placeholders like {{placeholder}} that will be interpolated using the fields in CritiqueRequest."""

media_object: Optional[MediaObject] = None
"""Optional image attached to the question for multimodal critique requests.
Stores the local path or URL of the image."""


@dataclass(frozen=True)
class CritiqueTaskTemplate:
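
A short sketch of how a caller can populate the new field; the image location is a placeholder:

from helm.common.critique_request import CritiqueQuestionTemplate, QuestionType
from helm.common.media_object import MediaObject

# Attach an image to a multiple-choice critique question via the new media_object field.
question = CritiqueQuestionTemplate(
    name="originality_gpt4v",
    question_type=QuestionType.MULTIPLE_CHOICE,
    text="How original is the text, given it was created with the image?",
    options=["Neutral.", "I find the text to be fresh and original."],
    media_object=MediaObject(location="path/to/image.png", content_type="image/png"),  # placeholder path
)
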
11 changes: 11 additions & 0 deletions src/helm/proxy/critique/model_critique_client.py
@@ -15,6 +15,7 @@
from helm.common.request import Request, RequestResult, GeneratedOutput
from helm.clients.client import Client
from helm.proxy.critique.critique_client import CritiqueClient
from helm.common.media_object import MultimediaObject, MediaObject


class CritiqueParseError(Exception):
@@ -31,6 +32,7 @@ def __init__(self, client: Client, model_name):
get_default_model_deployment_for_model(model_name, warn_arg_deprecated=False, ignore_deprecated=True)
or self._model_name
)
        self.vision_language = model_name.startswith("openai/gpt-4-vision")

def _interpolate_fields(self, text: str, fields: Dict[str, str]) -> str:
for key, value in fields.items():
@@ -78,12 +80,21 @@ def _task_to_requests(self, task: CritiqueTaskTemplate, fields: Dict[str, str])

prompt = anthropic.HUMAN_PROMPT + prompt + anthropic.AI_PROMPT

multimodal_prompt: Optional[MultimediaObject] = None
if self.vision_language:
assert question.media_object is not None, "Expect media_object for vision-language models"
image_media: MediaObject = question.media_object
text_media: MediaObject = MediaObject(text=prompt, content_type="text/plain")
multimodal_prompt = MultimediaObject(media_objects=[image_media, text_media])
prompt = "" # set to empty string to avoid conflicts with multimodal_prompt

request = Request(
model=self._model_name,
model_deployment=self._model_deployment_name,
prompt=prompt,
max_tokens=max_tokens,
echo_prompt=False,
multimodal_prompt=multimodal_prompt,
)
requests.append(request)
return requests
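
To make the vision-language branch concrete, a hedged sketch of the Request it assembles; the model name, image path, and question text below are placeholders:

from helm.common.media_object import MediaObject, MultimediaObject
from helm.common.request import Request

# The question text moves into multimodal_prompt next to the image; the plain prompt stays empty.
image_media = MediaObject(location="path/to/image.png", content_type="image/png")
text_media = MediaObject(text="How original is the text?", content_type="text/plain")
request = Request(
    model="openai/gpt-4-vision-preview",             # placeholder model name
    model_deployment="openai/gpt-4-vision-preview",  # placeholder deployment name
    prompt="",  # left empty to avoid conflicting with multimodal_prompt
    max_tokens=100,
    echo_prompt=False,
    multimodal_prompt=MultimediaObject(media_objects=[image_media, text_media]),
)
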
