Commit 0c5363e
fix collation
Signed-off-by: Farzad Abdolhosseini <[email protected]>
farzadab committed Feb 21, 2025
1 parent c7e0329 commit 0c5363e
Showing 1 changed file with 30 additions and 8 deletions.

vllm/model_executor/models/ultravox.py
@@ -2,6 +2,7 @@
 
 # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
 """PyTorch Ultravox model."""
+import math
 from functools import cached_property
 from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
                     Tuple, TypedDict, Union)
@@ -46,14 +47,14 @@
 
 class UltravoxAudioFeatureInputs(TypedDict):
     type: Literal["audio_features"]
-    data: torch.Tensor
+    data: NestedTensors
     """Shape: `(batch_size, num_audios, 80, M)`"""
-    lens: torch.Tensor
+    lens: NestedTensors
     """
     Length of the audio frames. Used for attention mask in WhisperEncoder.
     Shape: `(batch_size)`
     """
-    token_len: torch.Tensor
+    token_len: NestedTensors
     """
     Length of the audio tokens. Used for flattening the audio features.
     Shape: `(batch_size)`
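After this change, `data` (and likewise `lens` and `token_len`) can arrive either as one stacked tensor or as a list of per-request tensors whose trailing time dimension `M` differs; `NestedTensors` is vLLM's recursive tensor-or-list-of-tensors union. A minimal sketch of the two cases, with toy sizes that follow the docstrings above:

```python
import torch

# Homogeneous batch: audios share a padded length, so the features
# arrive pre-stacked as (batch_size, num_audios, 80, M).
stacked = torch.zeros(2, 1, 80, 3000)

# Heterogeneous batch: chunked audios of different lengths arrive as
# a list with one tensor of shape (num_chunks, 80, M_i) per request.
nested = [
    torch.zeros(1, 80, 1500),
    torch.zeros(2, 80, 3000),
]
```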
@@ -110,7 +111,11 @@ def get_mm_max_tokens_per_item(
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> Mapping[str, int]:
-        return {}
+        feature_extractor = self.get_feature_extractor()
+        max_audio_tokens = math.ceil(feature_extractor.chunk_length *
+                                     _AUDIO_TOKENS_PER_SECOND)
+
+        return {"audio": max_audio_tokens}
 
 
 class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
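For a sense of scale: with the Whisper feature extractor's default 30-second `chunk_length` and the module-level `_AUDIO_TOKENS_PER_SECOND` constant (6.25 in this file at the time; both concrete values are assumptions here, not part of the diff), the cap works out to 188 tokens per audio item:

```python
import math

chunk_length = 30  # seconds; WhisperFeatureExtractor default (assumed)
tokens_per_second = 6.25  # stands in for _AUDIO_TOKENS_PER_SECOND (assumed)

max_audio_tokens = math.ceil(chunk_length * tokens_per_second)
print(max_audio_tokens)  # 188
```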
@@ -422,6 +427,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
+        # Due to the batching of audio chunks, the preprocessor cache cannot
+        # do the right thing, so disable it.
+        vllm_config.model_config.disable_mm_preprocessor_cache = True
         multimodal_config = vllm_config.model_config.multimodal_config
         self.config = config
         self.multi_modal_config = multimodal_config
@@ -516,10 +524,24 @@ def _process_audio_input(
         if audio_input["type"] == "audio_embeds":
             return audio_input["data"]
 
-        # remove unneeded extra dimension added to all elements of mm_kwargs
-        audio_features = flatten_bn(audio_input["data"])
-        audio_lens = flatten_bn(audio_input["lens"])
-        audio_token_len = flatten_bn(audio_input["token_len"])
+        audio_features = audio_input["data"]
+        if isinstance(audio_features, list):
+            max_len = max(x.shape[-1] for x in audio_features)
+            # Pad and concatenate:
+            # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
+            audio_features = torch.cat(
+                [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_features])
+        else:
+            # Flatten [B, N, 80, M] -> [B * N, 80, M]
+            audio_features = flatten_bn(audio_features)
+
+        if isinstance(audio_input['lens'], list):
+            # [B1, B2] -> [B1+B2]
+            audio_lens = torch.cat(audio_input['lens'])
+            audio_token_len = torch.cat(audio_input['token_len'])
+        else:
+            audio_lens = flatten_bn(audio_input['lens'])
+            audio_token_len = flatten_bn(audio_input['token_len'])
 
         embeddings = self._audio_features_to_embeddings(
             audio_features, audio_lens)
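The heart of the fix is the pad-and-concatenate collation for the list case. A standalone sketch of that behavior on toy tensors (names and sizes are illustrative, not taken from the commit):

```python
import torch
import torch.nn.functional as F

# Two requests whose mel features differ in the time dimension M.
features = [
    torch.randn(1, 80, 100),  # request 1: [B1=1, 80, M1=100]
    torch.randn(2, 80, 160),  # request 2: [B2=2, 80, M2=160]
]

# Right-pad each tensor's last dimension to the longest M, then
# concatenate along the first (chunk) dimension:
# [[1, 80, 100], [2, 80, 160]] -> [3, 80, 160]
max_len = max(x.shape[-1] for x in features)
batched = torch.cat([F.pad(x, (0, max_len - x.shape[-1])) for x in features])
assert batched.shape == (3, 80, 160)

# The per-chunk frame lengths travel alongside so WhisperEncoder can
# mask out the padding; list-valued lens are concatenated the same way.
lens = torch.cat([torch.tensor([100]), torch.tensor([160, 160])])
assert lens.tolist() == [100, 160, 160]
```

Without the mask derived from `lens`, the encoder would attend to the zero frames introduced during collation, which is why the lengths are threaded through to `_audio_features_to_embeddings` alongside the padded features.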
