
[Bugfix] handle alignment of arguments in convert_sparse_cross_attention_mask_to_dense #12347

Merged

Conversation

@tjohnson31415 (Contributor) commented Jan 23, 2025

Reproducing the bug requires a batch containing both a text-only request and a request with an image. With the OpenAI server this can happen under load, but it is easier to reproduce in offline mode:

from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_image

image_url = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg"
image_data = fetch_image(image_url)

model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"

llm = LLM(
    model=model_name,
    max_model_len=4096,
    max_num_seqs=2,
    enforce_eager=True,
)

sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=128
)

# Batching a text-only prompt together with an image prompt triggers the bug.
outputs = llm.generate(
    prompts=[
        {
            "prompt": "What is the capital of Spain?",
        },
        {
            "prompt": "Analyze this image <|image|>. What do you see?",
            "multi_modal_data": {
                "image": image_data,
            },
        },
    ],
    sampling_params=sampling_params
)

Before #11939 this would have been a crash, but now it results in an AssertionError. The reason is that the num_tiles: List[List[int]] passed in to convert_sparse_cross_attention_mask_to_dense has no entry for a text-only sequence, while sparse_mask has one entry per sequence. The inputs to the function end up like:

sparse_mask = [[], [[5, -1]]]
num_tiles = [[4]]
lengths = [7, 12]

The fix in this PR is to skip [] entries in sparse_mask for text-only requests.
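For illustration, here is a minimal sketch of the alignment logic with the skip applied (a simplified reconstruction for this writeup, not the exact vLLM implementation): empty entries in sparse_mask advance the row offset by the sequence length but consume nothing from num_tiles.

import numpy as np
from typing import List

def convert_sparse_cross_attention_mask_to_dense(
        sparse_mask: List[List[List[int]]],  # per sequence: [start, end] span per image
        num_tiles: List[List[int]],          # tile counts, one entry per sequence WITH images
        lengths: List[int],                  # token length of each sequence in the batch
) -> np.ndarray:
    total_length = sum(lengths)
    total_tiles = sum(sum(t) for t in num_tiles)
    dense_mask = np.zeros((total_length, total_tiles), dtype=np.int64)

    tile_idx = 0    # column offset into the dense mask
    seq_start = 0   # row offset into the dense mask
    tiles_iter = iter(num_tiles)  # advances only for sequences that have images
    for seq_mask, length in zip(sparse_mask, lengths):
        if len(seq_mask) == 0:
            # Text-only sequence: no matching entry in num_tiles, so skip it.
            seq_start += length
            continue
        for (start, end), tiles in zip(seq_mask, next(tiles_iter)):
            if end == -1:
                end = length
            dense_mask[seq_start + start:seq_start + end,
                       tile_idx:tile_idx + tiles] = 1
            tile_idx += tiles
        seq_start += length
    return dense_mask

With the example inputs above, this returns a (19, 4) mask in which only rows 12-18 (tokens 5 onward of the image sequence) are set.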

A better fix may be to have num_tiles created with an entry for each sequence, but I didn't see where to make that change.

Potential fix for #10648

tjohnson31415 and others added 2 commits January 22, 2025 23:29
Co-authored-by: Wallas Santos <[email protected]>
Signed-off-by: Travis Johnson <[email protected]>
…_mask_to_dense

Without the alignment, an AssertionError is raised if a text-only sequence precedes one with an image.

Co-authored-by: Wallas Santos <[email protected]>
Signed-off-by: Travis Johnson <[email protected]>

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@DarkLight1337 (Member) left a comment


Thanks for fixing!

@heheda12345 (Collaborator) commented

LGTM! Thank you for the bug fix.
@DarkLight1337 Does @large_gpu_test(min_gb=48) mean that this new test will be skipped during CI? If so, I think we should implement some tests for get_cross_attention_mask (and also get_cross_attention_states & get_full_text_row_masked_out_mask if possible).

@DarkLight1337 (Member) commented

@DarkLight1337 Does @large_gpu_test(min_gb=48) mean that this new test will be skipped during CI?

Yes, that is correct.

@heheda12345 (Collaborator) commented

@tjohnson31415 Can you change the e2e test to a test for get_cross_attention_mask (and also get_cross_attention_states & get_full_text_row_masked_out_mask if possible)?

@wallashss (Contributor) commented

Hey @heheda12345, I added some tests that should be enough to prevent a regression on this issue. Please see if you agree.

Thanks!

cc @tjohnson31415
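For illustration, a minimal regression test in this spirit might look like the following (a hypothetical sketch, not the exact test added in this PR; it assumes the function is importable from vLLM's mllama module):

from vllm.model_executor.models.mllama import (
    convert_sparse_cross_attention_mask_to_dense)

def test_text_only_sequence_before_image_sequence():
    # The failing batch from the bug report: sequence 0 is text-only,
    # so it has an empty sparse mask and no entry in num_tiles.
    sparse_mask = [[], [[5, -1]]]
    num_tiles = [[4]]
    lengths = [7, 12]

    dense_mask = convert_sparse_cross_attention_mask_to_dense(
        sparse_mask, num_tiles, lengths)

    # One row per token in the batch, one column per image tile.
    assert dense_mask.shape == (sum(lengths), 4)
    # The text-only sequence (rows 0-6) and the tokens before the image
    # span in the second sequence (rows 7-11) attend to no tiles.
    assert not dense_mask[:12].any()
    assert dense_mask[12:].all()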

@heheda12345 (Collaborator) left a comment


LGTM! Thanks for the bug fix and new tests!

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) January 29, 2025 04:17
@github-actions github-actions bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) Jan 29, 2025
@DarkLight1337 DarkLight1337 merged commit 036ca94 into vllm-project:main Jan 29, 2025
61 checks passed
rasmith pushed a commit to rasmith/vllm that referenced this pull request Jan 30, 2025
…ion_mask_to_dense (vllm-project#12347)

Signed-off-by: Travis Johnson <[email protected]>
Signed-off-by: Wallas Santos <[email protected]>
Co-authored-by: Wallas Santos <[email protected]>
Isotr0py pushed a commit to Isotr0py/vllm that referenced this pull request Feb 2, 2025
…ion_mask_to_dense (vllm-project#12347)

Signed-off-by: Travis Johnson <[email protected]>
Signed-off-by: Wallas Santos <[email protected]>
Co-authored-by: Wallas Santos <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Feb 7, 2025
…ion_mask_to_dense (vllm-project#12347)

Signed-off-by: Travis Johnson <[email protected]>
Signed-off-by: Wallas Santos <[email protected]>
Co-authored-by: Wallas Santos <[email protected]>