[API][Engine] Support loading multiple models in a single engine #542

Merged: 5 commits into mlc-ai:main on Aug 13, 2024

Conversation


@CharlieFRuan commented Aug 12, 2024

This PR does not change the behavior of any existing API or workflow; it only introduces new behavior when multiple models are loaded in a single engine. For usage, please refer to examples/multi-models and the "User-facing changes" section below.

  import * as webllm from "@mlc-ai/web-llm";

  // Load two models into one engine by passing an array of model ids.
  const selectedModel1 = "Phi-3-mini-4k-instruct-q4f32_1-MLC-1k";
  const selectedModel2 = "gemma-2-2b-it-q4f32_1-MLC-1k";
  const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
    [selectedModel1, selectedModel2],
  );

User-facing changes

  • When multiple models are loaded in the engine:
    • The model field is required in ChatCompletionRequest, EmbeddingCreateParams, and CompletionCreateParams (see the sketch after this list)
    • The following engine APIs need to specify the target model: getMessage(), runtimeStatsText(), resetChat(), forwardTokensAndSample()
  • reload() / unload():
    • To load multiple models, call reload([model_id1, model_id2])
    • To load a single model, as before, call reload(model_id1)
    • unload() unloads all loaded models, or interrupts a current reload (as before)
  • CreateMLCEngine(), CreateWebWorkerEngine(), CreateServiceWorkerEngine() are all updated accordingly to allow string[] input for modelId
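
For illustration, a minimal sketch that continues the snippet above (the request shape follows the OpenAI-style API; the per-model `runtimeStatsText()` call reflects the bullet above, with its exact signature paraphrased rather than copied):

```typescript
// Sketch: with two models loaded, every request must name its target model.
const reply1 = await engine.chat.completions.create({
  model: selectedModel1, // required because multiple models are loaded
  messages: [{ role: "user", content: "Tell me a joke." }],
});
console.log(reply1.choices[0].message.content);

const reply2 = await engine.chat.completions.create({
  model: selectedModel2,
  messages: [{ role: "user", content: "Summarize machine learning in one sentence." }],
});
console.log(reply2.choices[0].message.content);

// Engine APIs that touch a specific pipeline also take the model id, e.g.:
console.log(await engine.runtimeStatsText(selectedModel1));
```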

Internal code change

Before diving in, we first identify the following APIs in MLCEngineInterface that directly interact with a pipeline, as mentioned above: getMessage(), forwardTokensAndSample(), runtimeStatsText(), resetChat()

support.ts, utils.ts

  • getModelIdToUse(): given the loaded model ids and the requestedModel (which can be undefined), determines which model to use (sketched after this list). Added unit test
  • areChatOptionsListEqual(): called when the service worker decides whether to skip loading. Added unit test
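
A minimal sketch of the reconciliation logic described above (illustrative only; the real implementation and its error types live in the files named above):

```typescript
// Illustrative sketch: pick the requested model if given, fall back to the
// single loaded model, and fail when the choice is ambiguous or unknown.
function getModelIdToUse(
  loadedModelIds: string[],
  requestedModel: string | undefined,
  requestName: string,
): string {
  if (loadedModelIds.length === 0) {
    throw new Error(`${requestName}: no model is loaded; call reload() first.`);
  }
  if (requestedModel === undefined) {
    if (loadedModelIds.length === 1) {
      return loadedModelIds[0];
    }
    throw new Error(
      `${requestName}: multiple models are loaded, so a model must be specified.`,
    );
  }
  if (!loadedModelIds.includes(requestedModel)) {
    throw new Error(`${requestName}: model ${requestedModel} is not loaded.`);
  }
  return requestedModel;
}
```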

src/openai_api_protocols

  • The only change is to remove model from the list of unsupported fields and update the documentation

types.ts

  • reload()'s modelId is now string | string[]. Update reload() and unload() documentation
  • For each of the user-facing APIs that directly interact with a pipeline, modelId is required when multiple models are loaded, as explained in "User-facing changes" (rough signatures sketched after this list)
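
Roughly, the updated signatures look like the sketch below (based on the bullets above; optional-parameter names and ordering are illustrative, not copied from the source):

```typescript
import type { ChatOptions } from "@mlc-ai/web-llm";

// Sketch of the updated shape of the interface; modelId becomes mandatory
// (enforced at runtime) once multiple models are loaded.
interface MultiModelEngineSketch {
  reload(modelId: string | string[], chatOpts?: ChatOptions | ChatOptions[]): Promise<void>;
  unload(): Promise<void>;
  getMessage(modelId?: string): Promise<string>;
  runtimeStatsText(modelId?: string): Promise<string>;
  resetChat(keepStats?: boolean, modelId?: string): Promise<void>;
  forwardTokensAndSample(inputIds: number[], isPrefill: boolean, modelId?: string): Promise<number>;
}
```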

src/engine.ts -- main changes happen here

  • Introduce private fields this.loadedModelIdToPipeline and this.loadedModelIdToChatConfig to replace this.currentModelId, this.config, this.pipeline. Update reloadInternal() and unload() to reflect this.
  • reload(): more preprocessing due to the possible string[] input. For reload abort, a single abort stops the loading of all models
  • Add getLLMStates() to query the currently loaded pipeline, model id, and chatConfig (sketched after this list)
  • Abandon any use of this.getPipeline(). Instead, each internal API (prefill(), decode(), _generate(), asyncGenerate()) knows which pipeline to execute on from its parameters, passed in by the high-level APIs (embedding(), chatCompletion(), completion())
  • Only the user-facing APIs that directly interact with a pipeline cannot pinpoint the pipeline from their parameters, hence they resort to getLLMStates(). These are: getMessage(), forwardTokensAndSample(), runtimeStatsText(), resetChat()
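
A sketch of this bookkeeping (field and method names follow the PR text; the pipeline and config types are stand-ins, and getModelIdToUse is the helper sketched under support.ts above):

```typescript
// Stand-ins for the engine's internal pipeline and chat-config types.
type LLMChatPipelineLike = unknown;
type ChatConfigLike = unknown;

// The reconciliation helper sketched earlier (declared here for self-containedness).
declare function getModelIdToUse(
  loadedModelIds: string[],
  requestedModel: string | undefined,
  requestName: string,
): string;

class EngineStateSketch {
  private loadedModelIdToPipeline = new Map<string, LLMChatPipelineLike>();
  private loadedModelIdToChatConfig = new Map<string, ChatConfigLike>();

  // Used by the user-facing APIs that cannot carry a pipeline in their
  // parameters: getMessage(), forwardTokensAndSample(), runtimeStatsText(), resetChat().
  private getLLMStates(requestName: string, modelId?: string) {
    const selectedModelId = getModelIdToUse(
      Array.from(this.loadedModelIdToPipeline.keys()),
      modelId,
      requestName,
    );
    return {
      selectedModelId,
      pipeline: this.loadedModelIdToPipeline.get(selectedModelId)!,
      chatConfig: this.loadedModelIdToChatConfig.get(selectedModelId)!,
    };
  }
}
```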

WebWorker changes: message.ts, extension_service_worker.ts, service_worker.ts, web_worker.ts

  • Due to the additional modelId? parameter for the user-facing APIs that directly interact with a pipeline (resetChat(), etc. as mentioned earlier):
    • We modify or add their params message in message.ts
    • Correspondingly, change the handling of these messages in WebWorkerMLCEngineHandler, and the sending of these messages in WebWorkerMLCEngine
  • To detect when the frontend and backend loaded models are out of sync (for more, see [Worker] Move reload upon worker termination to WebWorker #533):
    • Update Handler.modelId and Handler.chatOpts to string[] and ChatOptions[]
    • For each OpenAI-API-initiating params (e.g. ChatCompletionNonStreamingParams), also send string[] and ChatOptions[]. That is, we send all the expected loaded models, even though each such API only requires one. As a result, we reload if Handler.modelId and the modelId in the API's message params do not strictly match
  • Update reload(), ReloadParams, and the reload handling in the Handler in ServiceWorker (sketched after this list). To skip loading:
    • Before: this.modelId === params.modelId, Now: areArraysEqual(this.modelId, params.modelId)
    • Before: areChatOptionsEqual(...), Now: areChatOptionsListEqual(...)
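
A sketch of the updated skip-check (areArraysEqual and areChatOptionsListEqual are the helpers named above; the body below and the surrounding handler code are illustrative):

```typescript
// Illustrative body for the array comparison used by the skip-check.
function areArraysEqual(xs?: string[], ys?: string[]): boolean {
  if (xs === undefined || ys === undefined) return xs === ys;
  return xs.length === ys.length && xs.every((x, i) => x === ys[i]);
}

// Simplified shape of the reload handling in the handler:
//
//   if (
//     areArraysEqual(this.modelId, params.modelId) &&
//     areChatOptionsListEqual(this.chatOpts, params.chatOpts)
//   ) {
//     return; // same model list and same per-model options: skip the reload
//   }
//   await this.engine.reload(params.modelId, params.chatOpts);
```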

Tested

  • Unit tests
  • chat.completions.create() and embeddings.create() for MLCEngine, web worker, and service worker
  • examples/multi-models; also in web worker and service worker settings
  • simple-chat-ts; tested abortController() by removing await from reload() and asyncInitChat() in simple_chat.ts
  • WebLLMChat; tested terminating the service worker manually and re-sending a chat completion request

@CharlieFRuan changed the title from "[API][Engine] Support loading multiple models in a single MLCEngineInterface" to "[API][Engine] Support loading multiple models in a single engine" on Aug 12, 2024
- In reloadInternal, the abort signal aborts the loading of all models
- Remove engine.stopped()
- Make chatOpts optional in OpenAI-API-initiating message params
- Add E2E unit tests for getModelIdToUse
@CharlieFRuan marked this pull request as ready for review on August 13, 2024 08:10
@CharlieFRuan merged commit 78fed78 into mlc-ai:main on Aug 13, 2024
1 check passed
CharlieFRuan added a commit that referenced this pull request Aug 13, 2024
### Changes
- #541
- #542
  - When a single model is loaded, there is no change in behavior
  - When multiple models are loaded, some APIs need to specify which model they target
  - For more, see the PR description (the user-facing section)
  - Also see `examples/multi-models`

### TVMjs
Still compiled at apache/tvm@1fcb620, no change
CharlieFRuan added a commit that referenced this pull request Aug 13, 2024
This is a follow-up to #542.
Update `examples/multi-models` to use a web worker, and to also showcase
generating responses from two models concurrently from the same engine.
This was already supported for `MLCEngine` prior to this PR, but
`WebWorkerMLCEngine` needed a patch. Specifically:

- Prior to this PR, `WebWorkerMLCEngineHandler` maintained a single
`asyncGenerator`, assuming there is only one model.
- Now, to support concurrent streaming requests, we replace
`this.asyncGenerator` with `this.loadedModelIdToAsyncGenerator`, which
maps from a model id to its dedicated `asyncGenerator` (sketched below)
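
A minimal sketch of that replacement (the chunk type is a stand-in for the engine's streaming chunk type; the method names below are illustrative):

```typescript
// Stand-in for the engine's streaming chunk type.
type ChunkLike = unknown;

// One streaming generator per model id, so two models can stream
// concurrently from the same worker without clobbering each other's state.
class HandlerStreamingSketch {
  private loadedModelIdToAsyncGenerator = new Map<string, AsyncGenerator<ChunkLike>>();

  startStream(modelId: string, generator: AsyncGenerator<ChunkLike>): void {
    this.loadedModelIdToAsyncGenerator.set(modelId, generator);
  }

  async nextChunk(modelId: string): Promise<IteratorResult<ChunkLike>> {
    const generator = this.loadedModelIdToAsyncGenerator.get(modelId);
    if (generator === undefined) {
      throw new Error(`No streaming request in progress for model ${modelId}.`);
    }
    const result = await generator.next();
    if (result.done) {
      // Drop finished generators so the map only tracks in-flight streams.
      this.loadedModelIdToAsyncGenerator.delete(modelId);
    }
    return result;
  }
}
```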