Skip to content

Commit

Permalink
Adding difficulty as a parameter to Image2Structure (stanford-crfm#2660)
Browse files Browse the repository at this point in the history
  • Loading branch information
JosselinSomervilleRoberts authored and xuwangyin committed Jun 23, 2024
1 parent 1fff542 commit d2463b3
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 56 deletions.
39 changes: 34 additions & 5 deletions src/helm/benchmark/presentation/run_entries_image2structure.conf
Original file line number Diff line number Diff line change
@@ -1,20 +1,49 @@
# Conf file for Image2Structure
entries: [

# image2latex
# image2latex - all
{description: "image2latex:subset=equation,model=vlm", priority: 1, groups: ["image2latex"]}
{description: "image2latex:subset=table,model=vlm", priority: 1, groups: ["image2latex"]}
{description: "image2latex:subset=plot,model=vlm", priority: 1, groups: ["image2latex"]}
{description: "image2latex:subset=algorithm,model=vlm", priority: 1, groups: ["image2latex"]}
{description: "image2latex:subset=real,model=vlm", priority: 1, groups: ["image2latex"]}
# image2latex - easy
{description: "image2latex:subset=equation,difficulty=easy,model=vlm", priority: 1, groups: ["image2latex"]}
{description: "image2latex:subset=table,difficulty=easy,model=vlm", priority: 1, groups: ["image2latex"]}
{description: "image2latex:subset=plot,difficulty=easy,model=vlm", priority: 1, groups: ["image2latex"]}
{description: "image2latex:subset=algorithm,difficulty=easy,model=vlm", priority: 1, groups: ["image2latex"]}
# image2latex - medium
{description: "image2latex:subset=equation,difficulty=medium,model=vlm", priority: 1, groups: ["image2latex"]}
{description: "image2latex:subset=table,difficulty=medium,model=vlm", priority: 1, groups: ["image2latex"]}
{description: "image2latex:subset=plot,difficulty=medium,model=vlm", priority: 1, groups: ["image2latex"]}
{description: "image2latex:subset=algorithm,difficulty=medium,model=vlm", priority: 1, groups: ["image2latex"]}
# image2latex - hard
{description: "image2latex:subset=equation,difficulty=hard,model=vlm", priority: 1, groups: ["image2latex"]}
{description: "image2latex:subset=table,difficulty=hard,model=vlm", priority: 1, groups: ["image2latex"]}
{description: "image2latex:subset=plot,difficulty=hard,model=vlm", priority: 1, groups: ["image2latex"]}
{description: "image2latex:subset=algorithm,difficulty=hard,model=vlm", priority: 1, groups: ["image2latex"]}

# sheetmusic2lilypond
{description: "image2musicsheet:model=vlm", priority: 1, groups: ["image2musicsheet"]}
{description: "image2musicsheet:difficulty=easy,model=vlm", priority: 1, groups: ["image2musicsheet"]}
{description: "image2musicsheet:difficulty=medium,model=vlm", priority: 1, groups: ["image2musicsheet"]}
{description: "image2musicsheet:difficulty=hard,model=vlm", priority: 1, groups: ["image2musicsheet"]}

# webpages
# webpages - all
{description: "image2webpage:subset=css,model=vlm", priority: 1, groups: ["image2webpage"]}
{description: "image2webpage:subset=html,model=vlm", priority: 1, groups: ["image2webpage"]}
{description: "image2webpage:subset=javascript,model=vlm", priority: 1, groups: ["image2webpage"]}

# chart2csv
# {description: "chart2csv:model=vlm", priority: 1}
{description: "image2webpage:subset=real,model=vlm", priority: 1, groups: ["image2webpage"]}
# webpages - easy
{description: "image2webpage:subset=css,difficulty=easy,model=vlm", priority: 1, groups: ["image2webpage"]}
{description: "image2webpage:subset=html,difficulty=easy,model=vlm", priority: 1, groups: ["image2webpage"]}
{description: "image2webpage:subset=javascript,difficulty=easy,model=vlm", priority: 1, groups: ["image2webpage"]}
# webpages - medium
{description: "image2webpage:subset=css,difficulty=medium,model=vlm", priority: 1, groups: ["image2webpage"]}
{description: "image2webpage:subset=html,difficulty=medium,model=vlm", priority: 1, groups: ["image2webpage"]}
{description: "image2webpage:subset=javascript,difficulty=medium,model=vlm", priority: 1, groups: ["image2webpage"]}
# webpages - hard
{description: "image2webpage:subset=css,difficulty=hard,model=vlm", priority: 1, groups: ["image2webpage"]}
{description: "image2webpage:subset=html,difficulty=hard,model=vlm", priority: 1, groups: ["image2webpage"]}
{description: "image2webpage:subset=javascript,difficulty=hard,model=vlm", priority: 1, groups: ["image2webpage"]}
]
103 changes: 62 additions & 41 deletions src/helm/benchmark/run_specs/vlm_run_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
ADAPT_GENERATION_MULTIMODAL,
ADAPT_MULTIPLE_CHOICE_JOINT_MULTIMODAL,
)
from helm.benchmark.scenarios.vision_language.image2structure.image2structure_scenario import DIFFICULTY_ALL
from helm.benchmark.metrics.common_metric_specs import (
get_basic_reference_metric_specs,
get_exact_match_metric_specs,
Expand Down Expand Up @@ -421,10 +422,12 @@ def get_vqa_spec() -> RunSpec:


@run_spec_function("image2latex")
def get_image2latex_spec(subset: str, recompile_prompt: bool = False, args: Optional[Dict] = None) -> RunSpec:
def get_image2latex_spec(
subset: str, recompile_prompt: bool = False, difficulty: str = DIFFICULTY_ALL, args: Optional[Dict] = None
) -> RunSpec:
scenario_spec = ScenarioSpec(
class_name="helm.benchmark.scenarios.vision_language.image2structure.latex_scenario.LatexScenario",
args={"subset": subset, "recompile_prompt": recompile_prompt},
args={"subset": subset, "recompile_prompt": recompile_prompt, "difficulty": difficulty},
)
adapter_spec: AdapterSpec = _get_generation_adapter_spec(
instructions="Just give a short answer without answering in a complete sentence.",
Expand All @@ -442,22 +445,31 @@ def get_image2latex_spec(subset: str, recompile_prompt: bool = False, args: Opti
)
]

run_spec_name: str = "image2latex"
run_spec_name: str = f"image2latex:subset={subset}"
groups: List[str] = ["image2latex", f"image2{subset}"]
if difficulty != DIFFICULTY_ALL:
run_spec_name += f":difficulty={difficulty}"
groups += [f"image2latex-{difficulty}"]
return RunSpec(
name=f"{run_spec_name}:subset={subset}",
name=run_spec_name,
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=metric_specs,
groups=[run_spec_name],
groups=groups,
annotators=annotator_specs,
)


@run_spec_function("image2webpage")
def get_image2webpage_spec(subset: str, recompile_prompt: bool = False, args: Optional[Dict] = None) -> RunSpec:
def get_image2webpage_spec(
subset: str,
recompile_prompt: bool = False,
difficulty: str = DIFFICULTY_ALL,
args: Optional[Dict] = None,
) -> RunSpec:
scenario_spec = ScenarioSpec(
class_name="helm.benchmark.scenarios.vision_language.image2structure.webpage_scenario.WebpageScenario",
args={"subset": subset, "recompile_prompt": recompile_prompt},
args={"subset": subset, "recompile_prompt": recompile_prompt, "difficulty": difficulty},
)
adapter_spec: AdapterSpec = _get_generation_adapter_spec(
instructions="Just give a short answer without answering in a complete sentence.",
Expand All @@ -475,50 +487,27 @@ def get_image2webpage_spec(subset: str, recompile_prompt: bool = False, args: Op
)
]

run_spec_name: str = "image2webpage"
run_spec_name: str = f"image2webpage:subset={subset}"
groups: List[str] = ["image2webpage", f"image2{subset}"]
if difficulty != DIFFICULTY_ALL:
run_spec_name += f":difficulty={difficulty}"
groups += [f"image2webpage-{difficulty}"]
return RunSpec(
name=f"{run_spec_name}:subset={subset}",
name=run_spec_name,
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=metric_specs,
groups=[run_spec_name],
groups=groups,
annotators=annotator_specs,
)


@run_spec_function("math_vista")
def get_math_vista_spec(grade: str, question_type: str) -> RunSpec:
scenario_spec = ScenarioSpec(
class_name="helm.benchmark.scenarios.vision_language.math_vista_scenario.MathVistaScenario",
args={"grade": grade, "question_type": question_type},
)

adapter_spec: AdapterSpec
if question_type == "free_form":
adapter_spec = _get_short_answer_generation_adapter_spec()
elif question_type == "multi_choice":
adapter_spec = _get_multiple_choice_joint_adapter_spec(
input_noun=None, output_noun="Answer", max_train_instances=0
)
else:
raise ValueError(f"Invalid question type: {question_type}")

metric_specs: List[MetricSpec] = get_exact_match_metric_specs()
run_spec_name: str = "math_vista"
return RunSpec(
name=f"{run_spec_name}:grade={grade},question_type={question_type}",
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=metric_specs,
groups=[run_spec_name],
)


@run_spec_function("image2musicsheet")
def get_image2musicsheet_spec(args: Optional[Dict] = None) -> RunSpec:
def get_image2musicsheet_spec(difficulty: str = DIFFICULTY_ALL, args: Optional[Dict] = None) -> RunSpec:
scenario_spec = ScenarioSpec(
class_name="helm.benchmark.scenarios.vision_language.image2structure.musicsheet_scenario.MusicSheetScenario",
args={"subset": "music", "recompile_prompt": False}, # There os only one subset for music sheets
# There os only one subset for music sheets
args={"subset": "music", "recompile_prompt": False, "difficulty": difficulty},
)
adapter_spec: AdapterSpec = _get_generation_adapter_spec(
instructions="Just give a short answer without answering in a complete sentence.",
Expand All @@ -537,16 +526,48 @@ def get_image2musicsheet_spec(args: Optional[Dict] = None) -> RunSpec:
]

run_spec_name: str = "image2musicsheet"
groups: List[str] = ["image2musicsheet"]
if difficulty != DIFFICULTY_ALL:
run_spec_name += f":difficulty={difficulty}"
groups += [f"image2musicsheet-{difficulty}"]
return RunSpec(
name=run_spec_name,
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=metric_specs,
groups=[run_spec_name],
groups=groups,
annotators=annotator_specs,
)


@run_spec_function("math_vista")
def get_math_vista_spec(grade: str, question_type: str) -> RunSpec:
scenario_spec = ScenarioSpec(
class_name="helm.benchmark.scenarios.vision_language.math_vista_scenario.MathVistaScenario",
args={"grade": grade, "question_type": question_type},
)

adapter_spec: AdapterSpec
if question_type == "free_form":
adapter_spec = _get_short_answer_generation_adapter_spec()
elif question_type == "multi_choice":
adapter_spec = _get_multiple_choice_joint_adapter_spec(
input_noun=None, output_noun="Answer", max_train_instances=0
)
else:
raise ValueError(f"Invalid question type: {question_type}")

metric_specs: List[MetricSpec] = get_exact_match_metric_specs()
run_spec_name: str = "math_vista"
return RunSpec(
name=f"{run_spec_name}:grade={grade},question_type={question_type}",
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=metric_specs,
groups=[run_spec_name],
)


@run_spec_function("mmmu")
def get_mmmu_spec(subject: str, question_type: str) -> RunSpec:
scenario_spec = ScenarioSpec(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
from helm.common.hierarchical_logger import hlog

PROCESSED: str = "processed"
DIFFICULTY_ALL = "all"
DIFFICULTY_EASY = "easy"
DIFFICULTY_MEDIUM = "medium"
DIFFICULTY_HARD = "hard"


class Image2StructureScenario(Scenario):
Expand All @@ -38,13 +42,16 @@ class Image2StructureScenario(Scenario):
VALID_SPLIT: "validation",
}

def __init__(self, subset: str, recompile_prompt: bool = True, split: str = VALID_SPLIT):
def __init__(
self, subset: str, recompile_prompt: bool = True, split: str = VALID_SPLIT, difficulty: str = DIFFICULTY_ALL
):
super().__init__()
assert subset in self.SUBSETS, f"Invalid subset: {subset}"
self._subset: str = subset
self._recompile_prompt: bool = recompile_prompt
self._split: str = split
self._output_path: Optional[str] = None
self._difficulty: str = difficulty

def preprocess_row(self, row: Dict[str, Any], assets_path: str) -> Dict[str, Any]:
# By default, there are no assets
Expand Down Expand Up @@ -110,6 +117,10 @@ def get_instances(self, output_path: str) -> List[Instance]:
)
continue

# Filter by difficulty
if self._difficulty != DIFFICULTY_ALL and row["difficulty"] != self._difficulty:
continue

# Step 1: Preprocess the row
row = self.preprocess_row(row, assets_path)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from helm.benchmark.scenarios.scenario import VALID_SPLIT
from helm.benchmark.scenarios.vision_language.image2structure.utils_latex import (
latex_to_image,
strip_unnecessary_latex_parts,
Expand All @@ -14,9 +13,6 @@ class LatexScenario(Image2StructureScenario):
name = "image2latex"
description = "Evaluate multimodal models on Latex generation to recreate a provided image"

def __init__(self, subset: str, recompile_prompt: bool = True, split: str = VALID_SPLIT):
super().__init__(subset, recompile_prompt, split)

def compile_and_save(self, structure: str, assets_path: str, destination_path: str) -> str:
image, infos = latex_to_image(structure, assets_path=assets_path, crop=True)
image.save(destination_path)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from helm.benchmark.scenarios.scenario import VALID_SPLIT
from helm.benchmark.scenarios.vision_language.image2structure.image2structure_scenario import Image2StructureScenario


Expand All @@ -13,8 +12,5 @@ class MusicSheetScenario(Image2StructureScenario):
name = "image2musicsheet"
description = "Evaluate multimodal models on Lilypond generation to recreate a provided image"

def __init__(self, subset: str, recompile_prompt: bool = True, split: str = VALID_SPLIT):
super().__init__(subset, recompile_prompt, split)

def compile_and_save(self, structure: str, assets_path: str, destination_path: str) -> str:
raise Exception("Music sheets have no ground truth, compilation is not possible")
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from helm.benchmark.scenarios.vision_language.image2structure.image2structure_scenario import (
Image2StructureScenario,
PROCESSED,
DIFFICULTY_ALL,
)
from helm.benchmark.scenarios.vision_language.image2structure.webpage.jekyll_server import JekyllServer
from helm.benchmark.scenarios.vision_language.image2structure.webpage.driver import (
Expand Down Expand Up @@ -140,9 +141,10 @@ def __init__(
subset: str,
recompile_prompt: bool = True,
split: str = VALID_SPLIT,
difficulty: str = DIFFICULTY_ALL,
screenshot_options: ScreenshotOptions = ScreenshotOptions(),
):
super().__init__(subset, recompile_prompt, split)
super().__init__(subset, recompile_prompt, split, difficulty)
self._screenshot_options = screenshot_options
self._html2text = HTML2Text()
self._html2text.ignore_links = True
Expand Down

0 comments on commit d2463b3

Please sign in to comment.