
[Vision] Support Phi-3.5-vision, the first VLM in WebLLM #563

Merged: 7 commits into mlc-ai:main on Sep 23, 2024

Conversation


@CharlieFRuan CharlieFRuan commented Sep 18, 2024

This PR adds support for our first Vision Language Model, Phi-3.5-vision. For a full example, see examples/vision-model. Overall usage follows the OpenAI API and is shown below. We add Phi-3.5-vision-instruct-q4f16_1-MLC and Phi-3.5-vision-instruct-q4f32_1-MLC to the prebuilt model list.

  const messages: webllm.ChatCompletionMessageParam[] = [
    {
      role: "user",
      content: [
        { type: "text", text: "List the items in the image concisely." },
        {
          type: "image_url",
          image_url: { url: "https://www.ilankelman.org/sunset.jpg" },
        },
      ],
    },
  ];
  const request0: webllm.ChatCompletionRequest = {
    stream: false, // can be streaming, same behavior
    messages: messages,
  };
  const reply0 = await engine.chat.completions.create(request0);
  console.log(reply0);
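
The image_url also accepts a base64 data URL, since the helper added in this PR loads ImageData from either an http(s) URL or a base64 one. Below is a minimal sketch of the same request with a data URL; the base64 payload is a truncated placeholder, not real image data:

  const base64Messages: webllm.ChatCompletionMessageParam[] = [
    {
      role: "user",
      content: [
        { type: "text", text: "Describe this image concisely." },
        {
          type: "image_url",
          // Placeholder payload; substitute a full base64-encoded image.
          image_url: { url: "data:image/jpeg;base64,/9j/4AAQSkZJRg..." },
        },
      ],
    },
  ];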

Implementation details

To support vision models, we need several internal changes:

  • Generalize Conversation to non-string messages:
    • A message in Conversation.messages can now be not only a string but also an Array<ChatCompletionContentPart>, matching the two possible types of ChatCompletionUserMessageParam.content
    • Thus getPromptArray(), which converts the conversation history to a list of messages to prefill, returns Array<string | Array<string | ImageURL>> rather than just Array<string>
  • Implement getChunkedPrefillInputData() (see the first sketch after this list)
    • Prior to this PR, we implemented chunking naively because there were only token inputs; with image inputs, a single image cannot be split across chunks when embedding
    • Therefore, we first update getInputData() to work with ImageURLs
    • Then we use getChunkedPrefillInputData() to chunk the output of getInputData(); each chunk is embedded and forwarded before the next chunk is processed
    • getChunkedPrefillInputData() is tested thoroughly in unit tests
  • As a result, we need to separate embedding and forwarding (see the second sketch after this list), so we:
    • Replace forward() with embedAndForward(), which takes in a chunk, embeds it, and prefills / decodes
    • Implement getImageEmbeddings() and getTokensEmbeddings()
    • Implement concatEmbeddings() in TVMjs to concatenate token and image embeddings: [WASM] Implement concat embeddings apache/tvm#17404
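
To make the chunking rule concrete, here is a minimal TypeScript sketch of the idea behind getChunkedPrefillInputData(), not WebLLM's actual implementation: token runs may be split across chunk boundaries, while an image is atomic and must stay in one chunk. The type names, the hardcoded IMAGE_EMBED_SIZE handling, and the behavior when an image alone exceeds the chunk size are illustrative assumptions.

  // Illustrative sketch only; types and names are stand-ins for WebLLM's internals.
  interface ImageURL { url: string; }
  type PrefillInput = number[] | ImageURL;

  function chunkPrefillInputsSketch(
    inputs: PrefillInput[],
    chunkSize: number,
  ): PrefillInput[][] {
    const chunks: PrefillInput[][] = [];
    let current: PrefillInput[] = [];
    let currentLen = 0;
    const flush = () => {
      if (current.length > 0) {
        chunks.push(current);
        current = [];
        currentLen = 0;
      }
    };

    for (const input of inputs) {
      if (Array.isArray(input)) {
        // Token input: may be split freely across chunk boundaries.
        let tokens = input;
        while (tokens.length > 0) {
          const take = Math.min(chunkSize - currentLen, tokens.length);
          current.push(tokens.slice(0, take));
          currentLen += take;
          tokens = tokens.slice(take);
          if (currentLen === chunkSize) flush();
        }
      } else {
        // Image input: atomic. Assume it occupies IMAGE_EMBED_SIZE positions
        // (hardcoded to 1921 in this PR); if it does not fit in the current
        // chunk, start a new chunk rather than splitting the image.
        const IMAGE_EMBED_SIZE = 1921;
        if (currentLen + IMAGE_EMBED_SIZE > chunkSize) flush();
        current.push(input);
        currentLen += IMAGE_EMBED_SIZE;
        if (currentLen >= chunkSize) flush();
      }
    }
    flush();
    return chunks;
  }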
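
A companion sketch of the embed-then-forward flow for one chunk, reusing PrefillInput and ImageURL from the sketch above. The declared stubs stand in for the real getTokensEmbeddings(), getImageEmbeddings(), TVMjs concatEmbeddings, and the forward step; their signatures here are assumptions, not WebLLM's actual API.

  // Illustrative sketch only; Embedding and the stubs below are placeholders.
  type Embedding = Float32Array; // stand-in for a (seq_len, hidden_size) tensor

  declare function getTokensEmbeddingsStub(tokens: number[]): Promise<Embedding>;
  declare function getImageEmbeddingsStub(image: ImageURL): Promise<Embedding>;
  declare function concatEmbeddingsStub(parts: Embedding[]): Embedding;
  declare function forwardStub(embeddings: Embedding): Promise<void>;

  async function embedAndForwardSketch(chunk: PrefillInput[]): Promise<void> {
    // Embed every component of the chunk: token runs and whole images.
    const parts: Embedding[] = [];
    for (const input of chunk) {
      parts.push(
        Array.isArray(input)
          ? await getTokensEmbeddingsStub(input) // text tokens
          : await getImageEmbeddingsStub(input), // one whole image
      );
    }
    // Concatenate along the sequence dimension, then run prefill on the chunk.
    await forwardStub(concatEmbeddingsStub(parts));
  }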

Tested

Since we changed the logic in crucial methods such as Conversation.getPromptArray(), LLMChatPipeline's getInputData(), getChunkedPrefillInputData(), and forward(), we tested thoroughly for both vision and non-vision use cases. The following are tested end-to-end:

  • Examples: vision-model, simple-chat-ts, logit processor, get_started, multi-round
  • Prefilling multiple messages
  • Unit tests for Conversation and getChunkedPrefillInputData()
  • Prefilling 300 tokens with 32 prefill chunk size
  • Did not observe significant performance change in prefill and decode toks/s for non-image models

TODO

  • Currently, we hardcode IMAGE_EMBED_SIZE to 1921; that is, every image's embedding has shape (1921, hidden_size). We also hardcode num_crops to 16 in the kernel. We should expose num_crops and generalize IMAGE_EMBED_SIZE in the future.
  • In addition, as noted in [WASM] Add phi3.5-vision to webllm binary-mlc-llm-libs#140, the kernel embed_image should take input of type uint8, but the current workaround uses uint32 input since there is no Array<u8> in WGSL. This should be fixed (see the sketch after this list).
    • Naively using uint8 in the TIR kernel results in a runtime error:
      • @group(0) @binding(0) var<storage, read_write> T_transpose : array<u8> --> error : unresolved type 'u8'
  • Consider allocating a static embeddings memory where we keep writing to the same piece
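
On the uint8/uint32 point above, one way to feed byte-valued pixel data to a kernel that declares array<u32> is to widen each byte on the JS side before copying it into the input buffer. Whether WebLLM widens each byte or bit-packs four bytes per word is not stated here, so the sketch below is an assumption, not the PR's actual workaround:

  // Assumption: each uint8 value gets its own 32-bit slot (widening, not packing).
  function widenU8ToU32(bytes: Uint8Array): Uint32Array {
    const out = new Uint32Array(bytes.length);
    for (let i = 0; i < bytes.length; i++) {
      out[i] = bytes[i]; // 0-255 value stored as a uint32
    }
    return out;
  }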

Related PRs

- Enable content being an array, which can have image_url
- Introduce ModelType.VLM so that only VLM can handle non-string message content
- Thus pass in loadedModelType to postInitCheck, hence add loadedModelIdToModelType in Engine
- Change unit tests correspondingly
- getPromptArray element can now be an array itself
- Conversation.messages message can either be a string or an array of content parts
- Update compareConversationObject
- Next step is to update llm_chat.ts getInputTokens
- Implement getInputData to replace getInputTokens
  - Instead of returning a list of tokens, return a list that mixes number[] and ImageURL
- Implement getChunkedPrefillInputData that transforms output of getInputData into chunks, thoroughly tested
- Replace forward with embedAndForward, which takes in a chunk, embeds it, and forwards it
- Within embedAndForward, we first embed all components in the chunk
- Note that chunking is taken care of in getChunkedPrefillInputData, so embedAndForward has a data length less than the prefill chunk size
- TODOs: implement getImageEmbeddings, concatenation of embeddings, E2E test with image input
- Tested:
  - simple chat, logit processor, get_started, multi-round example
  - prefill multiple messages
  - 32 prefill chunk size, prefill 300 tokens
- Implement getImageEmbedding, which loads the image into ImageData and calls the embed_image kernel
- Use TVM global function concatEmbeddings to combine text and image embeddings
- Implement a helper function that loads ImageData from a URL, either http or base64 (a sketch follows this list)
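
As a rough illustration of such a helper (not WebLLM's actual code): both http(s) URLs and base64 data URLs can be fetched, decoded to an ImageBitmap, and drawn onto an OffscreenCanvas to obtain ImageData.

  // Minimal browser-side sketch; error handling and resizing are omitted.
  async function loadImageDataFromURL(url: string): Promise<ImageData> {
    const response = await fetch(url); // works for both http(s) and data: URLs
    const blob = await response.blob();
    const bitmap = await createImageBitmap(blob);
    const canvas = new OffscreenCanvas(bitmap.width, bitmap.height);
    const ctx = canvas.getContext("2d");
    if (ctx === null) {
      throw new Error("Failed to get a 2D context from OffscreenCanvas");
    }
    ctx.drawImage(bitmap, 0, 0);
    return ctx.getImageData(0, 0, bitmap.width, bitmap.height);
  }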
@CharlieFRuan (Contributor, Author) commented:

Tested with examples/vision-model; prefilling two images gives a prefill speed of 190 tokens/s and a decode speed of 30 tokens/s.


@CharlieFRuan CharlieFRuan marked this pull request as ready for review September 23, 2024 01:25
@CharlieFRuan CharlieFRuan changed the title [WIP][Vision] Support Phi-3.5-vision [Vision] Support Phi-3.5-vision, the first VLM in WebLLM Sep 23, 2024
@CharlieFRuan CharlieFRuan merged commit 9c0aec4 into mlc-ai:main Sep 23, 2024
1 check passed
CharlieFRuan added a commit that referenced this pull request Sep 23, 2024
### Changes
- The only change is the support of Phi-3.5-vision:
  - #563
- Added `Phi-3.5-vision-instruct-q4f16_1-MLC` and `Phi-3.5-vision-instruct-q4f32_1-MLC` to the prebuilt model list
- See `examples/vision-model` for how to use a vision language model in WebLLM

### TVMjs
- Compiled at
apache/tvm@931efc7
  - Cherry-picked apache/tvm#17404 on top
- Note this does not require us to recompile non-vision models, because text-only inputs do not need embeddings concatenation
- WASMs: still the same `v0_2_48` WASMs

atlury commented Sep 23, 2024

Excellent work @CharlieFRuan, it works really nicely. Please continue the good work and include Qwen2-VL as well. 👍 Eagerly waiting.
