[V1][Molmo] Fix get_multimodal_embeddings() in molmo.py (#14161)
lk-chen authored Mar 4, 2025
1 parent c8525f0 commit b3cf368
Showing 22 changed files with 249 additions and 150 deletions.
294 changes: 176 additions & 118 deletions examples/offline_inference/vision_language.py

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion vllm/model_executor/models/aria.py
@@ -602,7 +602,9 @@ def _process_image_input(
 
         return self.multi_modal_projector(image_outputs, image_attn_mask)
 
-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+        self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
4 changes: 3 additions & 1 deletion vllm/model_executor/models/blip2.py
@@ -628,7 +628,9 @@ def _process_image_input(self,
 
         return self.language_projection(query_output)
 
-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+        self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
4 changes: 3 additions & 1 deletion vllm/model_executor/models/chameleon.py
@@ -986,7 +986,9 @@ def _parse_and_validate_image_input(
             data=self._validate_pixel_values(pixel_values),
         )
 
-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+        self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
4 changes: 3 additions & 1 deletion vllm/model_executor/models/deepseek_vl2.py
@@ -606,7 +606,9 @@ def _process_image_input(
         return self._pixel_values_to_embedding(
             pixel_values=pixel_values, images_spatial_crop=images_spatial_crop)
 
-    def get_multimodal_embeddings(self, **kwargs: object) -> torch.Tensor:
+    def get_multimodal_embeddings(
+        self, **kwargs: object
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
4 changes: 3 additions & 1 deletion vllm/model_executor/models/florence2.py
@@ -1037,7 +1037,9 @@ def _process_image_input(
         pixel_values = image_input["data"]
         return self._encode_image(pixel_values)
 
-    def get_multimodal_embeddings(self, **kwargs: object) -> torch.Tensor:
+    def get_multimodal_embeddings(
+        self, **kwargs: object
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
6 changes: 4 additions & 2 deletions vllm/model_executor/models/fuyu.py
@@ -18,7 +18,7 @@
 """ PyTorch Fuyu model."""
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import List, Literal, Optional, Set, Tuple, TypedDict
+from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
 
 import torch
 import torch.nn as nn
@@ -327,7 +327,9 @@ def _process_image_input(
                                                      image_patches_flat)
         return vision_embeddings_flat.split(patches_per_image, dim=0)
 
-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+        self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
4 changes: 3 additions & 1 deletion vllm/model_executor/models/glm4v.py
@@ -595,7 +595,9 @@ def _process_image_input(
 
         return self.transformer.vision(pixel_values)
 
-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+        self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
4 changes: 3 additions & 1 deletion vllm/model_executor/models/idefics3.py
@@ -617,7 +617,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
         self.sampler = get_sampler()
 
-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+        self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self.model._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
18 changes: 9 additions & 9 deletions vllm/model_executor/models/interfaces.py
@@ -4,6 +4,7 @@
                     Protocol, Type, Union, overload, runtime_checkable)
 
 import torch
+from torch import Tensor
 from typing_extensions import TypeIs, TypeVar
 
 from vllm.logger import init_logger
@@ -15,12 +16,11 @@
 
 if TYPE_CHECKING:
     from vllm.attention import AttentionMetadata
-    from vllm.multimodal.inputs import NestedTensors  # noqa: F401
     from vllm.sequence import IntermediateTensors
 
 logger = init_logger(__name__)
 
-T = TypeVar("T", default="NestedTensors")
+T = TypeVar("T", default=Union[list[Tensor], Tensor, tuple[Tensor, ...]])
@@ -36,7 +36,7 @@ class SupportsMultiModal(Protocol):
     MRO of your model class.
     """
 
-    def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:
+    def get_multimodal_embeddings(self, **kwargs) -> T:
         """
         Returns multimodal embeddings generated from multimodal kwargs
         to be merged with text embeddings.
@@ -59,18 +59,18 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:
     @overload
     def get_input_embeddings(
         self,
-        input_ids: torch.Tensor,
+        input_ids: Tensor,
         multimodal_embeddings: Optional[T] = None,
         attn_metadata: Optional["AttentionMetadata"] = None,
-    ) -> torch.Tensor:
+    ) -> Tensor:
         ...
 
     @overload
     def get_input_embeddings(
         self,
-        input_ids: torch.Tensor,
+        input_ids: Tensor,
         multimodal_embeddings: Optional[T] = None,
-    ) -> torch.Tensor:
+    ) -> Tensor:
         """
         Returns the input embeddings merged from the text embeddings from
         input_ids and the multimodal embeddings generated from multimodal
@@ -210,7 +210,7 @@ def forward(
         self,
         *,
         intermediate_tensors: Optional["IntermediateTensors"],
-    ) -> Union[torch.Tensor, "IntermediateTensors"]:
+    ) -> Union[Tensor, "IntermediateTensors"]:
         """
         Accept :class:`IntermediateTensors` when PP rank > 0.
@@ -237,7 +237,7 @@ def forward(
         self,
         *,
         intermediate_tensors: Optional["IntermediateTensors"],
-    ) -> Union[torch.Tensor, "IntermediateTensors"]:
+    ) -> Union[Tensor, "IntermediateTensors"]:
         ...


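For readers skimming the interface change: the TypeVar default above widens what get_multimodal_embeddings may return. Below is a minimal, self-contained sketch of that contract, not the vLLM source; the ToyModel class, its kwargs, and its tensor shapes are invented for illustration.

# Hedged sketch of the updated SupportsMultiModal contract; only the
# TypeVar/Protocol shape mirrors interfaces.py, everything else is invented.
from typing import Optional, Protocol, Union, runtime_checkable

import torch
from torch import Tensor
from typing_extensions import TypeVar

# Same default as the diff above: a tensor, a list of tensors, or a tuple.
T = TypeVar("T", default=Union[list[Tensor], Tensor, tuple[Tensor, ...]])


@runtime_checkable
class SupportsMultiModal(Protocol):

    def get_multimodal_embeddings(self, **kwargs) -> T:
        ...


class ToyModel:
    """Invented example implementation: one embedding tensor per image."""

    def get_multimodal_embeddings(
        self, **kwargs
    ) -> Union[list[Tensor], Tensor, tuple[Tensor, ...]]:
        pixel_values: Optional[list[Tensor]] = kwargs.get("pixel_values")
        if pixel_values is None:
            # Models in this diff still return None when no mm input arrives.
            return None
        return [torch.zeros(16, 64) for _ in pixel_values]


assert isinstance(ToyModel(), SupportsMultiModal)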
4 changes: 3 additions & 1 deletion vllm/model_executor/models/internvl.py
@@ -904,7 +904,9 @@ def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
         else:
             self.visual_token_mask = None
 
-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+        self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
4 changes: 3 additions & 1 deletion vllm/model_executor/models/llava.py
@@ -635,7 +635,9 @@ def _process_image_input(self,
         image_features = self._process_image_pixels(image_input)
         return self.multi_modal_projector(image_features)
 
-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+        self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
4 changes: 3 additions & 1 deletion vllm/model_executor/models/llava_next.py
@@ -479,7 +479,9 @@ def _process_image_input(
             for i, patch_features_batch in enumerate(patch_embeddings)
         ]
 
-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+        self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
4 changes: 3 additions & 1 deletion vllm/model_executor/models/llava_next_video.py
@@ -420,7 +420,9 @@ def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs):
         raise ValueError(
             f"Unsupported type of video input {type(video_pixels)}")
 
-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+        self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         video_input = self._parse_and_validate_video_input(**kwargs)
         if video_input is None:
             return None
9 changes: 6 additions & 3 deletions vllm/model_executor/models/molmo.py
@@ -50,7 +50,7 @@
                              PromptInsertion, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.utils import JSONTree, json_map_leaves
+from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves
 
 from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
                          SupportsQuant)
@@ -1576,21 +1576,24 @@ def _get_mm_embeds(
 
         return embeds_in_batch
 
-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+        self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
 
         image_features = self._process_image_input(image_input)
 
-        return [
+        nested_embeds = [
             self._get_mm_embeds(*args) for args in zip(
                 image_features,
                 image_input["feat_is_patch"],
                 image_input["num_crops"],
                 image_input["embed_is_patch"],
             )
         ]
+        return flatten_2d_lists(nested_embeds)
 
     def get_input_embeddings(
         self,
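The molmo.py hunk above is the core of the commit: the old code returned a list of per-image lists, one nesting level deeper than the flat sequence of tensors that callers of get_multimodal_embeddings expect, so the fix flattens before returning. A minimal sketch of that step, assuming vllm.utils.flatten_2d_lists is a plain one-level flatten (the helper below is a stand-in, not the vLLM implementation, and the tensor shapes are invented):

# Stand-in for vllm.utils.flatten_2d_lists, assumed to be a one-level flatten.
from typing import TypeVar

import torch

_T = TypeVar("_T")


def flatten_2d_lists(lists: list[list[_T]]) -> list[_T]:
    """Flatten a list of lists into a single flat list."""
    return [item for sublist in lists for item in sublist]


# Two images, each contributing a list of embedding tensors (shapes invented).
nested_embeds = [
    [torch.zeros(3, 64), torch.zeros(5, 64)],  # image 1
    [torch.zeros(4, 64)],                      # image 2
]
flat = flatten_2d_lists(nested_embeds)
assert len(flat) == 3  # a flat sequence of tensors, as the interface expects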
4 changes: 3 additions & 1 deletion vllm/model_executor/models/paligemma.py
@@ -263,7 +263,9 @@ def _process_image_input(
 
         return self.multi_modal_projector(image_features)
 
-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+        self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
4 changes: 3 additions & 1 deletion vllm/model_executor/models/phi3v.py
@@ -648,7 +648,9 @@ def _process_image_input(
 
         return image_embeds
 
-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+        self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
4 changes: 3 additions & 1 deletion vllm/model_executor/models/pixtral.py
@@ -220,7 +220,9 @@ def sampler(self):
 
         return get_sampler()
 
-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+        self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input, image_tokens = self._parse_and_validate_image_input(
             **kwargs)
         if image_input is None:
4 changes: 3 additions & 1 deletion vllm/model_executor/models/qwen2_audio.py
@@ -356,7 +356,9 @@ def _process_audio_input(self,
         return torch.split(masked_audio_features,
                            audio_output_lengths.flatten().tolist())
 
-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+        self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         audio_input = self._parse_and_validate_audio_input(**kwargs)
         if audio_input is None:
             return None
4 changes: 3 additions & 1 deletion vllm/model_executor/models/qwen_vl.py
@@ -740,7 +740,9 @@ def _process_image_input(self,
 
         return self.transformer.visual(image_input["data"])
 
-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+        self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
4 changes: 3 additions & 1 deletion vllm/model_executor/models/ultravox.py
@@ -476,7 +476,9 @@ def _process_audio_input(
 
         return result
 
-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+        self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         audio_input = self._parse_and_validate_audio_input(**kwargs)
         if audio_input is None:
             return None
4 changes: 3 additions & 1 deletion vllm/model_executor/models/whisper.py
@@ -692,7 +692,9 @@ def forward(
         )
         return decoder_outputs
 
-    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(
+        self, **kwargs
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
         # TODO: This method does not obey the interface for SupportsMultiModal.
         # Refactor this once encoder/decoder support is implemented in V1.
         audio_input = self._parse_and_validate_audio_input(**kwargs)
