Skip to content

Commit

Permalink
[Model] Expose dynamic_image_size as mm_processor_kwargs for Intern…
Browse files Browse the repository at this point in the history
…VL2 models (#10518)

Signed-off-by: Isotr0py <[email protected]>
  • Loading branch information
Isotr0py authored Nov 21, 2024
1 parent 8a93a59 commit d5ec121
Show file tree
Hide file tree
Showing 2 changed files with 255 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
"""Tests for InternVL's multimodal preprocessing kwargs."""
from typing import Callable, Optional

import pytest
from transformers import AutoTokenizer

from vllm.inputs import InputContext, token_inputs
from vllm.multimodal import MultiModalRegistry

from .....conftest import _ImageAssets
from ....utils import build_model_context

models = ["OpenGVLab/InternVL2-2B"]


# Wrap lazy imports to avoid initializing CUDA during test collection
@pytest.fixture()
def input_processor_for_internvl():
from vllm.model_executor.models.internvl import InternVLInputPipeline

pipeline = InternVLInputPipeline('<img>', '</img>', '<IMG_CONTEXT>')
return pipeline.input_processor


@pytest.fixture()
def dummy_data_for_internvl():
from vllm.model_executor.models.internvl import InternVLInputPipeline

pipeline = InternVLInputPipeline('<img>', '</img>', '<IMG_CONTEXT>')
return pipeline.dummy_data


@pytest.fixture()
def get_max_internvl_image_tokens():
from vllm.model_executor.models.internvl import (
get_max_internvl_image_tokens)
return get_max_internvl_image_tokens


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("max_dynamic_patch", [1, 4])
@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
def test_input_mapper_override(
model: str,
image_assets: _ImageAssets,
max_dynamic_patch: int,
dynamic_image_size: Optional[bool],
):
mm_processor_kwargs = {
"max_dynamic_patch": max_dynamic_patch,
}
if dynamic_image_size is not None:
mm_processor_kwargs["dynamic_image_size"] = dynamic_image_size

expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
if dynamic_image_size is False:
expected_num_patches = 1

ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
)

mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)

image = image_assets[0].pil_image.resize((448 * 2, 448 * 2))
vllm_result = mm_registry.map_input(
ctx.model_config,
{"image": image},
)
assert vllm_result["pixel_values"].size(1) == expected_num_patches


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("max_dynamic_patch", [1, 4, None])
@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
def test_max_tokens_override(
get_max_internvl_image_tokens: Callable,
model: str,
max_dynamic_patch: Optional[int],
dynamic_image_size: Optional[bool],
):
"""Ensure get_max_internvl_image_tokens handles mm_processor_kwargs."""
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=None,
)

if max_dynamic_patch is None:
max_dynamic_patch = ctx.get_hf_config().max_dynamic_patch
expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
if dynamic_image_size is False:
expected_num_patches = 1
expected_max_tokens = 256 * expected_num_patches

actual_max_tokens = get_max_internvl_image_tokens(
ctx=InputContext(ctx.model_config),
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
assert expected_max_tokens == actual_max_tokens


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("num_imgs", [1, 2])
@pytest.mark.parametrize("max_dynamic_patch", [1, 4, None])
@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
def test_dummy_data_override(
dummy_data_for_internvl: Callable,
model: str,
num_imgs: int,
max_dynamic_patch: Optional[int],
dynamic_image_size: Optional[bool],
):
"""Ensure dummy_data_for_internvl handles kwargs properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the dummy data func.
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=None,
)

if max_dynamic_patch is None:
max_dynamic_patch = ctx.get_hf_config().max_dynamic_patch
expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
if dynamic_image_size is False:
expected_num_patches = 1
expected_max_tokens = 256 * expected_num_patches

dummy_data = dummy_data_for_internvl(
ctx=ctx,
seq_len=8192, # Should be bigger than num_imgs * toks_per_img
mm_counts={"image": num_imgs},
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
sequence_data = dummy_data.seq_data

tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
image_token_id = tokenizer.encode('<IMG_CONTEXT>',
add_special_tokens=False)[0]

# Ensure we have the right number of placeholders per size
img_tok_count = sequence_data.get_token_ids().count(image_token_id)
assert img_tok_count == expected_max_tokens * num_imgs


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("max_dynamic_patch", [1, 4])
@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_input_processor_override(
input_processor_for_internvl: Callable,
image_assets: _ImageAssets,
model: str,
num_imgs: int,
max_dynamic_patch: int,
dynamic_image_size: Optional[bool],
):
"""Ensure input_processor_for_internvl handles kwargs properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the custom input processor.
expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
if dynamic_image_size is False:
expected_num_patches = 1

ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=None,
)
expected_toks_per_img = 256 * expected_num_patches

# Build the image str / prompt based on the number of images we pass
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
placeholders = "<image>" if num_imgs == 1 else "\n".join(
f"Image-{i}: <image>\n" for i in range(1, num_imgs + 1))
prompt = placeholders
images = [image_assets[0].pil_image.resize((448 * 2, 448 * 2))] * num_imgs

inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
prompt=prompt,
multi_modal_data={"image": images})

processed_inputs = input_processor_for_internvl(
ctx,
inputs,
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)

# Ensure we have the right number of placeholders per num_crops size
image_token_id = tokenizer.encode('<IMG_CONTEXT>',
add_special_tokens=False)[0]
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
assert img_tok_count == expected_toks_per_img * num_imgs
63 changes: 49 additions & 14 deletions vllm/model_executor/models/internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,15 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
return blocks, target_width, target_height


def calculate_num_blocks_wrapper(hf_config: PretrainedConfig,
max_dynamic_patch: Optional[int] = None):
def calculate_num_blocks_wrapper(
hf_config: PretrainedConfig,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
):
if dynamic_image_size is None:
dynamic_image_size = hf_config.dynamic_image_size

max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
if max_dynamic_patch is None:
max_dynamic_patch = hf_config.max_dynamic_patch
min_num = hf_config.min_dynamic_patch
Expand Down Expand Up @@ -183,10 +190,17 @@ def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
return pixel_values


def image_to_pixel_values_wrapper(hf_config: PretrainedConfig,
max_dynamic_patch: Optional[int] = None):
def image_to_pixel_values_wrapper(
hf_config: PretrainedConfig,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
):
image_size = hf_config.vision_config.image_size
min_num = hf_config.min_dynamic_patch
if dynamic_image_size is None:
dynamic_image_size = hf_config.dynamic_image_size

max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
if max_dynamic_patch is None:
max_dynamic_patch = hf_config.max_dynamic_patch
use_thumbnail = hf_config.use_thumbnail
Expand All @@ -207,11 +221,17 @@ def get_internvl_num_patches(hf_config: PretrainedConfig):
(downsample_ratio**2))


def get_max_internvl_image_tokens(ctx: InputContext,
*,
max_dynamic_patch: Optional[int] = None):
def get_max_internvl_image_tokens(
ctx: InputContext,
*,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
):
hf_config = ctx.get_hf_config()
if dynamic_image_size is None:
dynamic_image_size = hf_config.dynamic_image_size

max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
if max_dynamic_patch is None:
max_dynamic_patch = hf_config.max_dynamic_patch
use_thumbnail = hf_config.use_thumbnail
Expand All @@ -222,12 +242,18 @@ def get_max_internvl_image_tokens(ctx: InputContext,
return num_patches * max_dynamic_patch


def get_max_internvl_image_size(ctx: InputContext,
*,
max_dynamic_patch: Optional[int] = None):
def get_max_internvl_image_size(
ctx: InputContext,
*,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
):
hf_config = ctx.get_hf_config()
image_size = hf_config.vision_config.image_size
if dynamic_image_size is None:
dynamic_image_size = hf_config.dynamic_image_size

max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
if max_dynamic_patch is None:
max_dynamic_patch = hf_config.max_dynamic_patch
use_thumbnail = hf_config.use_thumbnail
Expand Down Expand Up @@ -281,6 +307,7 @@ def input_processor(
inputs: DecoderOnlyInputs,
*,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
) -> DecoderOnlyInputs:
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
Expand All @@ -292,7 +319,7 @@ def input_processor(
image_data = multi_modal_data["image"]
num_patches = get_internvl_num_patches(hf_config)
num_blocks_calculator = calculate_num_blocks_wrapper(
hf_config, max_dynamic_patch)
hf_config, max_dynamic_patch, dynamic_image_size)
if isinstance(image_data, Image.Image):
width, height = image_data.size
num_blocks, _, _ = num_blocks_calculator(width, height)
Expand Down Expand Up @@ -332,11 +359,12 @@ def input_mapper(
data: object,
*,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
):
hf_config = ctx.get_hf_config()

image_pixel_values_mapper = image_to_pixel_values_wrapper(
hf_config, max_dynamic_patch)
hf_config, max_dynamic_patch, dynamic_image_size)
if isinstance(data, Image.Image):
data = image_pixel_values_mapper(data)
# Add an N dimension for number of images per prompt (currently 1).
Expand Down Expand Up @@ -366,13 +394,17 @@ def dummy_data(
mm_counts: Mapping[str, int],
*,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
):
num_images = mm_counts["image"]

hf_config = ctx.get_hf_config()

image_feature_size = get_max_internvl_image_tokens(
ctx, max_dynamic_patch=max_dynamic_patch)
ctx,
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
Expand All @@ -388,7 +420,10 @@ def dummy_data(
)

max_image_width, max_image_height = get_max_internvl_image_size(
ctx, max_dynamic_patch=max_dynamic_patch)
ctx,
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)

mm_data = dummy_image_for_clip(
hf_config.vision_config,
Expand Down

0 comments on commit d5ec121

Please sign in to comment.