diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index f59c233883ce2..ab9a74767335f 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -2,6 +2,7 @@
 # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
 """PyTorch Ultravox model."""

+import math
 from functools import cached_property
 from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
                     Tuple, TypedDict, Union)
@@ -46,14 +47,14 @@ class UltravoxAudioFeatureInputs(TypedDict):
     type: Literal["audio_features"]
-    data: torch.Tensor
+    data: NestedTensors
     """Shape: `(batch_size, num_audios, 80, M)`"""
-    lens: torch.Tensor
+    lens: NestedTensors
     """
     Length of the audio frames. Used for attention mask in WhisperEncoder.
     Shape: `(batch_size)`
     """
-    token_len: torch.Tensor
+    token_len: NestedTensors
     """
     Length of the audio tokens. Used for flattening the audio features.
     Shape: `(batch_size)`
@@ -110,7 +111,11 @@ def get_mm_max_tokens_per_item(
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> Mapping[str, int]:
-        return {}
+        feature_extractor = self.get_feature_extractor()
+        max_audio_tokens = math.ceil(feature_extractor.chunk_length *
+                                     _AUDIO_TOKENS_PER_SECOND)
+
+        return {"audio": max_audio_tokens}


 class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
@@ -422,6 +427,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
+        # Due to the batching of audio chunks, the preprocessor cache cannot
+        # do the right thing so disable it.
+        vllm_config.model_config.disable_mm_preprocessor_cache = True
         multimodal_config = vllm_config.model_config.multimodal_config
         self.config = config
         self.multi_modal_config = multimodal_config
@@ -516,10 +524,24 @@ def _process_audio_input(
         if audio_input["type"] == "audio_embeds":
             return audio_input["data"]

-        # remove unneeded extra dimension added to all elements of mm_kwargs
-        audio_features = flatten_bn(audio_input["data"])
-        audio_lens = flatten_bn(audio_input["lens"])
-        audio_token_len = flatten_bn(audio_input["token_len"])
+        audio_features = audio_input["data"]
+        if isinstance(audio_features, list):
+            max_len = max(x.shape[-1] for x in audio_features)
+            # Pad and concatenate:
+            # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
+            audio_features = torch.cat(
+                [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_features])
+        else:
+            # Flatten [B, N, 80, M] -> [B * N, 80, M]
+            audio_features = flatten_bn(audio_features)
+
+        if isinstance(audio_input['lens'], list):
+            # [B1, B2] -> [B1+B2]
+            audio_lens = torch.cat(audio_input['lens'])
+            audio_token_len = torch.cat(audio_input['token_len'])
+        else:
+            audio_lens = flatten_bn(audio_input['lens'])
+            audio_token_len = flatten_bn(audio_input['token_len'])

         embeddings = self._audio_features_to_embeddings(
             audio_features, audio_lens)
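
For reference, a minimal standalone sketch of the pad-and-concatenate step that the new list branch in _process_audio_input performs when audio_input["data"] arrives as a list of per-request feature tensors. The tensor shapes and values below are hypothetical, chosen only to illustrate the shape arithmetic, and are not taken from the patch.

import torch
import torch.nn.functional as F

# Hypothetical inputs: two requests with 1 and 2 audio chunks respectively,
# each with 80 mel bins but a different number of frames (10 vs. 16).
audio_features = [torch.randn(1, 80, 10), torch.randn(2, 80, 16)]

# Right-pad the last (frame) dimension of every tensor to the longest M,
# then concatenate along the batch dimension:
# [(1, 80, 10), (2, 80, 16)] -> (3, 80, 16)
max_len = max(x.shape[-1] for x in audio_features)
batched = torch.cat(
    [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_features])

assert batched.shape == (3, 80, 16)

The per-chunk frame lengths (the "lens" field) are carried alongside so the WhisperEncoder attention mask can ignore the padded frames.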