
[Model] Extend Ultravox to accept audio longer than 30s #13631

Open · wants to merge 27 commits into main
Conversation

@farzadab (Contributor) commented Feb 20, 2025

Currently the Ultravox model input is capped at 30 seconds and any extra audio is truncated (AFAIK). Also, each sample is fed to Whisper individually (without being batched).

This PR supports longer audio by first chunking it, running the Whisper encoder in batch mode, and then concatenating the resulting embeddings.

TODO:

  • processors on HF still need to be updated in tandem with this PR.
  • run evaluations with the updated model to verify the changes.

@farzadab farzadab marked this pull request as draft February 20, 2025 21:29

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mgoin (Member) commented Feb 21, 2025

FYI @NickLucche for the usage of Whisper.

@NickLucche (Contributor) commented

Thanks for the contrib!
What is the chunking logic for tiling the audio? Feel free to link the HF processor PR.

@farzadab (Contributor, Author) commented Feb 21, 2025

re @NickLucche: Here's the processor link: https://huggingface.co/fixie-ai/ultravox-v0_3-llama-3_2-1b/blob/main/ultravox_processing.py#L209

The logic: each audio is split into 30-second chunks (the last chunk is not padded to 30s, same as before).
We then flatten and batch everything and run Whisper as if the chunks were separate audios. We use audio_lens to compute an attention_mask for the last chunk of each audio. The final embeddings are then concatenated.
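
For illustration, here's a minimal sketch of that flow. It is not the actual processor code: `whisper_encoder` is a hypothetical stand-in for the batched Whisper encoder, and the 16 kHz sample rate is an assumption; the real implementation also trims padded frames using audio_lens.

```python
import torch

SAMPLE_RATE = 16_000            # assumed Whisper input sample rate
CHUNK = 30 * SAMPLE_RATE        # 30-second chunk size, in samples

def encode_long_audios(audios: list[torch.Tensor], whisper_encoder):
    # 1) Split each audio into 30s chunks; the last chunk is NOT padded.
    chunks, audio_lens, num_chunks = [], [], []
    for audio in audios:
        parts = list(torch.split(audio, CHUNK))
        chunks.extend(parts)
        audio_lens.extend(len(p) for p in parts)
        num_chunks.append(len(parts))

    # 2) Flatten and batch: pad to the longest chunk in the batch and use
    #    audio_lens to build an attention_mask that ignores the padding.
    max_len = max(audio_lens)
    batch = torch.zeros(len(chunks), max_len)
    mask = torch.zeros(len(chunks), max_len, dtype=torch.bool)
    for i, (chunk, n) in enumerate(zip(chunks, audio_lens)):
        batch[i, :n] = chunk
        mask[i, :n] = True

    # 3) One batched encoder call, as if the chunks were separate audios.
    embeds = whisper_encoder(batch, attention_mask=mask)

    # 4) Concatenate each audio's chunk embeddings back together.
    out, start = [], 0
    for n in num_chunks:
        out.append(torch.cat(list(embeds[start:start + n]), dim=0))
        start += n
    return out
```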

There are other ways we could have done this, but it matches what we do on the Ultravox side for both fine-tuning and evals. If we end up updating those, I'll update vLLM as well.

Also, note that since we don't pad the last chunk, and since most audios are shorter than 30s, the number of frames does not match across samples. I didn't see a collator anywhere that I could update, so I suspect I'll have to update _process_audio_input further to handle that. Edit: updated _process_audio_input.
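
To illustrate the collation problem (a sketch, with an assumed hidden size of 1280): feature tensors with differing frame counts can't be stacked directly and have to be padded, with the true lengths kept alongside.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two chunk-feature tensors with different frame counts.
feats = [torch.randn(98, 1280), torch.randn(150, 1280)]

# torch.stack(feats) would raise a RuntimeError here: sizes don't match.
padded = pad_sequence(feats, batch_first=True)         # shape (2, 150, 1280)
feat_lens = torch.tensor([f.shape[0] for f in feats])  # keep true lengths
```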

@NickLucche (Contributor) commented

OK, I see. That's naive chunking, then: it doesn't account for splitting mid-word, and there's no overlap or prompt carried over from the previous chunk.

This case seems much easier to handle on the vLLM side, given the changes are already in HF. Let's just make sure the batched Whisper forward is accounted for by the initial profiler run, to avoid OOM.

@farzadab farzadab marked this pull request as ready for review February 26, 2025 00:53
@farzadab (Contributor, Author) commented Feb 28, 2025

Thanks for the comments. This PR is now ready for review.

For reference, I can confirm that the evals have improved.

before (8B model):

```
                                eval    subset model samples  score tokens
0                 audio-bigbench-30s         -  vllm    None  66.67   None
1             audio-bigbench-nolimit         -  vllm    None  62.60   None
2       audio-translate-covost-en_de     en_de  vllm    None  28.60   None
3       audiobench-dream-tts-mcq-30s         -  vllm    None  85.41   None
4   audiobench-dream-tts-mcq-nolimit         -  vllm    None  76.79   None
```

after (8B model):

```
                                eval    subset model samples  score tokens
0                 audio-bigbench-30s         -  vllm    None  67.42   None
1             audio-bigbench-nolimit         -  vllm    None  65.10   None
2       audio-translate-covost-en_de     en_de  vllm    None  28.66   None
3       audiobench-dream-tts-mcq-30s         -  vllm    None  84.92   None
4   audiobench-dream-tts-mcq-nolimit         -  vllm    None  84.89   None
```

Rows 0, 2, and 3 are there as a sanity check; a difference of less than 1 point is usually not significant (especially on model-as-judge evals). "30s" means the subset of samples that are under 30 seconds long; "nolimit" means the full set, and those are the sets where we see the 3- and 8-point improvements.

A similar trend is seen on the 70B model, which reaches 90.30 on audio-bigbench-nolimit compared to the 82.9 we had reported before.

@NickLucche (Contributor) left a comment


Great work here. Surprised to see how big of a leap a simple chunking strategy can achieve!

@farzadab (Contributor, Author) commented

Thanks!

> Surprised to see how big of a leap a simple chunking strategy can achieve!

Just to clarify, the difference in metrics is not because of a "better" chunking strategy. It's just that before this, we used to throw away any audio past 30 seconds. Any chunking strategy is probably better than no strategy 😅

@DarkLight1337 (Member) commented

Can you update tests/models/decoder_only/audio_language/test_ultravox.py back to using v0.5 as well?

@DarkLight1337 (Member) left a comment


Thanks for the update! Let's see if the tests pass.

@DarkLight1337 DarkLight1337 changed the title [WIP][Model] Extend Ultravox to accept audio longer than 30s [Model] Extend Ultravox to accept audio longer than 30s Mar 1, 2025
@DarkLight1337 DarkLight1337 added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 1, 2025
@ywang96 (Member) left a comment


Hello @farzadab! I left a few comments regarding the CI failure. Thanks for your contribution!

```python
curr = curr.get(key, {})
curr.pop(keys[-1], None)

return _convert_tensors_to_list(result)
```
@farzadab (Contributor, Author) commented

Without doing _convert_tensors_to_list, the equals comparison kept failing even on the same inputs:

E             Differing items:
E             {'mm_kwargs': {'audio_num_chunks': tensor([1, 1, 1]), 'audio_lens': tensor([4, 5, 6]), 'audio_token_len': tensor([1, 1, 1], dtype=torch.int32)}} !=
E             {'mm_kwargs': {'audio_num_chunks': tensor([1, 1, 1]), 'audio_lens': tensor([4, 5, 6]), 'audio_token_len': tensor([1, 1, 1], dtype=torch.int32)}}
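
For anyone hitting the same thing, this is the generic failure mode when comparing dicts of multi-element tensors with `==` (a sketch, not the actual test code): Python's dict comparison calls the tensors' elementwise `__eq__` and then tries to reduce the result to a single bool, which is ambiguous, so converting tensors to lists sidesteps it.

```python
import torch

a = {"audio_lens": torch.tensor([4, 5, 6])}
b = {"audio_lens": torch.tensor([4, 5, 6])}
try:
    a == b  # dict __eq__ -> tensor __eq__ -> bool() on a multi-element tensor
except RuntimeError as e:
    print(e)  # "Boolean value of Tensor with more than one element is ambiguous..."

# After converting tensors to lists, plain equality behaves as expected.
print({k: v.tolist() for k, v in a.items()} ==
      {k: v.tolist() for k, v in b.items()})  # True
```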

@DarkLight1337 (Member) commented Mar 5, 2025


MultiModalKwargs needs to be handled separately because it has a custom equality check.

@farzadab (Contributor, Author) commented Mar 5, 2025


oh gawd, that's why. I kept banging my head on the wall for so long on this issue 😢

The issue here is that _items_by_modality still keeps a version of audio_features.

@mergify mergify bot added the multi-modality Related to multi-modality (#4194) label Mar 7, 2025