diff --git a/examples/multi-models/src/main.ts b/examples/multi-models/src/main.ts
new file mode 100644
index 00000000..242d98be
--- /dev/null
+++ b/examples/multi-models/src/main.ts
@@ -0,0 +1,134 @@
+/**
+ * This example demonstrates loading multiple models in the same engine concurrently.
+ * sequentialGeneration() runs inference on each model, one at a time.
+ * parallelGeneration() runs inference on both models at the same time.
+ * This example uses WebWorkerMLCEngine, but the same idea applies to MLCEngine and
+ * ServiceWorkerMLCEngine as well.
+ */
+
+import * as webllm from "@mlc-ai/web-llm";
+
+function setLabel(id: string, text: string) {
+  const label = document.getElementById(id);
+  if (label == null) {
+    throw Error("Cannot find label " + id);
+  }
+  label.innerText = text;
+}
+
+const initProgressCallback = (report: webllm.InitProgressReport) => {
+  setLabel("init-label", report.text);
+};
+
+// Prepare request for each model, same for both methods
+const selectedModel1 = "Phi-3-mini-4k-instruct-q4f32_1-MLC-1k";
+const selectedModel2 = "gemma-2-2b-it-q4f32_1-MLC-1k";
+const prompt1 = "Tell me about California in 3 short sentences.";
+const prompt2 = "Tell me about New York City in 3 short sentences.";
+setLabel("prompt-label-1", `(with model ${selectedModel1})\n` + prompt1);
+setLabel("prompt-label-2", `(with model ${selectedModel2})\n` + prompt2);
+
+const request1: webllm.ChatCompletionRequestStreaming = {
+  stream: true,
+  stream_options: { include_usage: true },
+  messages: [{ role: "user", content: prompt1 }],
+  model: selectedModel1, // without specifying it, an error will be thrown due to ambiguity
+  max_tokens: 128,
+};
+
+const request2: webllm.ChatCompletionRequestStreaming = {
+  stream: true,
+  stream_options: { include_usage: true },
+  messages: [{ role: "user", content: prompt2 }],
+  model: selectedModel2, // without specifying it, an error will be thrown due to ambiguity
+  max_tokens: 128,
+};
+
+/**
+ * Chat completion (OpenAI style) with streaming, with two models in the pipeline, one after the other.
+ */
+async function sequentialGeneration() {
+  const engine = await webllm.CreateWebWorkerMLCEngine(
+    new Worker(new URL("./worker.ts", import.meta.url), { type: "module" }),
+    [selectedModel1, selectedModel2],
+    { initProgressCallback: initProgressCallback },
+  );
+
+  const asyncChunkGenerator1 = await engine.chat.completions.create(request1);
+  let message1 = "";
+  for await (const chunk of asyncChunkGenerator1) {
+    // console.log(chunk);
+    message1 += chunk.choices[0]?.delta?.content || "";
+    setLabel("generate-label-1", message1);
+    if (chunk.usage) {
+      console.log(chunk.usage); // only last chunk has usage
+    }
+    // engine.interruptGenerate(); // works with interrupt as well
+  }
+  const asyncChunkGenerator2 = await engine.chat.completions.create(request2);
+  let message2 = "";
+  for await (const chunk of asyncChunkGenerator2) {
+    // console.log(chunk);
+    message2 += chunk.choices[0]?.delta?.content || "";
+    setLabel("generate-label-2", message2);
+    if (chunk.usage) {
+      console.log(chunk.usage); // only last chunk has usage
+    }
+    // engine.interruptGenerate(); // works with interrupt as well
+  }
+
+  // without specifying which model to get the message from, an error will be thrown due to ambiguity
+  console.log("Final message 1:\n", await engine.getMessage(selectedModel1));
+  console.log("Final message 2:\n", await engine.getMessage(selectedModel2));
+}
+
+/**
+ * Chat completion (OpenAI style) with streaming, with two models in the pipeline, served concurrently.
+ */
+async function parallelGeneration() {
+  const engine = await webllm.CreateWebWorkerMLCEngine(
+    new Worker(new URL("./worker.ts", import.meta.url), { type: "module" }),
+    [selectedModel1, selectedModel2],
+    { initProgressCallback: initProgressCallback },
+  );
+
+  // We can serve the two requests concurrently
+  let message1 = "";
+  let message2 = "";
+
+  async function getModel1Response() {
+    const asyncChunkGenerator1 = await engine.chat.completions.create(request1);
+    for await (const chunk of asyncChunkGenerator1) {
+      // console.log(chunk);
+      message1 += chunk.choices[0]?.delta?.content || "";
+      setLabel("generate-label-1", message1);
+      if (chunk.usage) {
+        console.log(chunk.usage); // only last chunk has usage
+      }
+      // engine.interruptGenerate(); // works with interrupt as well
+    }
+  }
+
+  async function getModel2Response() {
+    const asyncChunkGenerator2 = await engine.chat.completions.create(request2);
+    for await (const chunk of asyncChunkGenerator2) {
+      // console.log(chunk);
+      message2 += chunk.choices[0]?.delta?.content || "";
+      setLabel("generate-label-2", message2);
+      if (chunk.usage) {
+        console.log(chunk.usage); // only last chunk has usage
+      }
+      // engine.interruptGenerate(); // works with interrupt as well
+    }
+  }
+
+  await Promise.all([getModel1Response(), getModel2Response()]);
+
+  // without specifying which model to get the message from, an error will be thrown due to ambiguity
+  console.log("Final message 1:\n", await engine.getMessage(selectedModel1));
+  console.log("Final message 2:\n", await engine.getMessage(selectedModel2));
}
+
+// Pick one to run
+sequentialGeneration();
+// parallelGeneration();
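Reviewer note: the example above only exercises the streaming path. For reference, the same two-model engine also serves ordinary non-streaming requests; the sketch below is not part of the patch and simply assumes the OpenAI-style non-streaming API (`choices[0].message.content`), with the same model ids as above and illustrative prompts.

```ts
import * as webllm from "@mlc-ai/web-llm";

// Sketch: non-streaming chat completions against the same two-model WebWorkerMLCEngine.
async function nonStreamingTwoModels() {
  const engine = await webllm.CreateWebWorkerMLCEngine(
    new Worker(new URL("./worker.ts", import.meta.url), { type: "module" }),
    ["Phi-3-mini-4k-instruct-q4f32_1-MLC-1k", "gemma-2-2b-it-q4f32_1-MLC-1k"],
  );

  // `model` must be specified: with two loaded models, an unspecified model is ambiguous.
  const reply1 = await engine.chat.completions.create({
    messages: [{ role: "user", content: "Tell me about California in 3 short sentences." }],
    model: "Phi-3-mini-4k-instruct-q4f32_1-MLC-1k",
  });
  const reply2 = await engine.chat.completions.create({
    messages: [{ role: "user", content: "Tell me about New York City in 3 short sentences." }],
    model: "gemma-2-2b-it-q4f32_1-MLC-1k",
  });
  console.log(reply1.choices[0].message.content);
  console.log(reply2.choices[0].message.content);
}
```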

diff --git a/examples/multi-models/src/multi_models.html b/examples/multi-models/src/multi_models.html
index 1de9c00b..b7fbf479 100644
--- a/examples/multi-models/src/multi_models.html
+++ b/examples/multi-models/src/multi_models.html
@@ -10,14 +10,21 @@
     <h2>WebLLM Test Page</h2>
     <br />
     <label id="init-label"> </label>
-    <h3>Prompt</h3>
-    <label id="prompt-label"> </label>
+    <h3>Prompt 1</h3>
+    <label id="prompt-label-1"> </label>
 
-    <h3>Response</h3>
-    <label id="generate-label"> </label>
+    <h3>Response from model 1</h3>
+    <label id="generate-label-1"> </label>
+    <br />
+
+    <h3>Prompt 2</h3>
+    <label id="prompt-label-2"> </label>
+
+    <h3>Response from model 2</h3>
+    <label id="generate-label-2"> </label>
     <br />
     <label id="stats-label"> </label>
 
-    <script type="module" src="./multi_models.ts"></script>
+    <script type="module" src="./main.ts"></script>
   </body>
 </html>
diff --git a/examples/multi-models/src/multi_models.ts b/examples/multi-models/src/multi_models.ts
deleted file mode 100644
index afafe684..00000000
--- a/examples/multi-models/src/multi_models.ts
+++ /dev/null
@@ -1,76 +0,0 @@
-import * as webllm from "@mlc-ai/web-llm";
-
-function setLabel(id: string, text: string) {
-  const label = document.getElementById(id);
-  if (label == null) {
-    throw Error("Cannot find label " + id);
-  }
-  label.innerText = text;
-}
-
-/**
- * Chat completion (OpenAI style) with streaming, with two models in the pipeline.
- */
-async function mainStreaming() {
-  const initProgressCallback = (report: webllm.InitProgressReport) => {
-    setLabel("init-label", report.text);
-  };
-  const selectedModel1 = "Phi-3-mini-4k-instruct-q4f32_1-MLC-1k";
-  const selectedModel2 = "gemma-2-2b-it-q4f32_1-MLC-1k";
-
-  const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
-    [selectedModel1, selectedModel2],
-    { initProgressCallback: initProgressCallback },
-  );
-
-  const request1: webllm.ChatCompletionRequest = {
-    stream: true,
-    stream_options: { include_usage: true },
-    messages: [
-      { role: "user", content: "Provide me three US states." },
-      { role: "assistant", content: "California, New York, Pennsylvania." },
-      { role: "user", content: "Two more please!" },
-    ],
-    model: selectedModel1, // without specifying it, error will throw due to ambiguity
-  };
-
-  const request2: webllm.ChatCompletionRequest = {
-    stream: true,
-    stream_options: { include_usage: true },
-    messages: [
-      { role: "user", content: "Provide me three cities in NY." },
-      { role: "assistant", content: "New York, Binghamton, Buffalo." },
-      { role: "user", content: "Two more please!" },
-    ],
-    model: selectedModel2, // without specifying it, error will throw due to ambiguity
-  };
-
-  const asyncChunkGenerator1 = await engine.chat.completions.create(request1);
-  let message = "";
-  for await (const chunk of asyncChunkGenerator1) {
-    console.log(chunk);
-    message += chunk.choices[0]?.delta?.content || "";
-    setLabel("generate-label", message);
-    if (chunk.usage) {
-      console.log(chunk.usage); // only last chunk has usage
-    }
-    // engine.interruptGenerate(); // works with interrupt as well
-  }
-  const asyncChunkGenerator2 = await engine.chat.completions.create(request2);
-  message += "\n\n";
-  for await (const chunk of asyncChunkGenerator2) {
-    console.log(chunk);
-    message += chunk.choices[0]?.delta?.content || "";
-    setLabel("generate-label", message);
-    if (chunk.usage) {
-      console.log(chunk.usage); // only last chunk has usage
-    }
-    // engine.interruptGenerate(); // works with interrupt as well
-  }
-
-  // without specifying from which model to get message, error will throw due to ambiguity
-  console.log("Final message 1:\n", await engine.getMessage(selectedModel1));
-  console.log("Final message 2:\n", await engine.getMessage(selectedModel2));
-}
-
-mainStreaming();
diff --git a/examples/multi-models/src/worker.ts b/examples/multi-models/src/worker.ts
new file mode 100644
index 00000000..6c62240b
--- /dev/null
+++ b/examples/multi-models/src/worker.ts
@@ -0,0 +1,7 @@
+import { WebWorkerMLCEngineHandler } from "@mlc-ai/web-llm";
+
+// Hook up an engine to a worker handler
+const handler = new WebWorkerMLCEngineHandler();
+self.onmessage = (msg: MessageEvent) => {
+  handler.onmessage(msg);
+};
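Reviewer note: as the header comment in main.ts says, the same two-model pattern also works without a worker. The sketch below condenses what the removed multi_models.ts above did with the in-page MLCEngine; model ids come from the example, the prompt is illustrative.

```ts
import * as webllm from "@mlc-ai/web-llm";

// Sketch: the same two-model flow using the in-page MLCEngine (no worker).
async function inPageTwoModels() {
  const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine([
    "Phi-3-mini-4k-instruct-q4f32_1-MLC-1k",
    "gemma-2-2b-it-q4f32_1-MLC-1k",
  ]);

  const chunks = await engine.chat.completions.create({
    stream: true,
    stream_options: { include_usage: true },
    messages: [{ role: "user", content: "Provide me three US states." }],
    model: "Phi-3-mini-4k-instruct-q4f32_1-MLC-1k", // required: two models are loaded
  });
  let reply = "";
  for await (const chunk of chunks) {
    reply += chunk.choices[0]?.delta?.content || "";
    if (chunk.usage) console.log(chunk.usage); // only the last chunk carries usage
  }
  console.log(reply);
  console.log(await engine.getMessage("Phi-3-mini-4k-instruct-q4f32_1-MLC-1k"));
}
```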
diff --git a/src/message.ts b/src/message.ts
index 618dd51a..8be38bb7 100644
--- a/src/message.ts
+++ b/src/message.ts
@@ -65,6 +65,11 @@ export interface ForwardTokensAndSampleParams {
 // handler will call reload(). An engine can load multiple models, hence both are list.
 // TODO(webllm-team): should add appConfig here as well if rigorous.
 // Fore more, see https://github.com/mlc-ai/web-llm/pull/471
+
+// Note on the messages with selectedModelId:
+// This is the modelId this request uses. It is needed to identify which async generator
+// to instantiate / use, since an engine can load multiple models and the handler
+// therefore needs to maintain multiple generators.
 export interface ChatCompletionNonStreamingParams {
   request: ChatCompletionRequestNonStreaming;
   modelId: string[];
@@ -72,6 +77,7 @@ export interface ChatCompletionNonStreamingParams {
 }
 export interface ChatCompletionStreamInitParams {
   request: ChatCompletionRequestStreaming;
+  selectedModelId: string;
   modelId: string[];
   chatOpts?: ChatOptions[];
 }
@@ -82,6 +88,7 @@ export interface CompletionNonStreamingParams {
 }
 export interface CompletionStreamInitParams {
   request: CompletionCreateParamsStreaming;
+  selectedModelId: string;
   modelId: string[];
   chatOpts?: ChatOptions[];
 }
@@ -90,6 +97,9 @@ export interface EmbeddingParams {
   modelId: string[];
   chatOpts?: ChatOptions[];
 }
+export interface CompletionStreamNextChunkParams {
+  selectedModelId: string;
+}
 
 export interface CustomRequestParams {
   requestName: string;
@@ -106,6 +116,7 @@ export type MessageContent =
   | CompletionNonStreamingParams
   | CompletionStreamInitParams
   | EmbeddingParams
+  | CompletionStreamNextChunkParams
   | CustomRequestParams
   | InitProgressReport
   | LogLevel
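Reviewer note: to make the new fields concrete, here is a sketch of the two streaming-related worker messages the frontend now sends. The shapes follow the interfaces above and the web_worker.ts changes below; the uuid values, model ids, and the abbreviated request body are placeholders.

```ts
// Sketch: stream-init message, handled once to instantiate the generator for "model-A".
const chatCompletionStreamInitMsg = {
  kind: "chatCompletionStreamInit",
  uuid: "placeholder-uuid-0", // crypto.randomUUID() in the real code
  content: {
    request: { stream: true, messages: [], model: "model-A" }, // abbreviated ChatCompletionRequestStreaming
    selectedModelId: "model-A", // which model's generator to instantiate
    modelId: ["model-A", "model-B"], // everything the frontend expects to be loaded
    chatOpts: undefined,
  },
};

// Sketch: per-chunk message; selectedModelId tells the handler which generator to call next() on.
const completionStreamNextChunkMsg = {
  kind: "completionStreamNextChunk",
  uuid: "placeholder-uuid-1",
  content: {
    selectedModelId: "model-A",
  },
};
```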
diff --git a/src/web_worker.ts b/src/web_worker.ts
index 6243f828..c683c004 100644
--- a/src/web_worker.ts
+++ b/src/web_worker.ts
@@ -36,6 +36,7 @@ import {
   CompletionStreamInitParams,
   GetMessageParams,
   RuntimeStatsTextParams,
+  CompletionStreamNextChunkParams,
 } from "./message";
 import log from "loglevel";
 import { MLCEngine } from "./engine";
@@ -44,6 +45,7 @@ import {
   WorkerEngineModelNotLoadedError,
 } from "./error";
 import { areArraysEqual } from "./utils";
+import { getModelIdToUse } from "./support";
 
 /**
  * Worker handler that can be used in a WebWorker
@@ -70,11 +72,10 @@ export class WebWorkerMLCEngineHandler {
   chatOpts?: ChatOptions[];
 
   public engine: MLCEngine;
-  /** ChatCompletion and Completion share the same chunk generator. */
-  protected asyncGenerate?: AsyncGenerator<
-    ChatCompletionChunk | Completion,
-    void,
-    void
+  /** ChatCompletion and Completion share the same chunk generator. Each loaded model has its own. */
+  protected loadedModelIdToAsyncGenerator: Map<
+    string,
+    AsyncGenerator<ChatCompletionChunk | Completion, void, void>
   >;
 
   /**
@@ -82,6 +83,10 @@ export class WebWorkerMLCEngineHandler {
    */
   constructor() {
     this.engine = new MLCEngine();
+    this.loadedModelIdToAsyncGenerator = new Map<
+      string,
+      AsyncGenerator<ChatCompletionChunk | Completion, void, void>
+    >();
     this.engine.setInitProgressCallback((report: InitProgressReport) => {
       const msg: WorkerResponse = {
         kind: "initProgressCallback",
@@ -178,10 +183,16 @@ export class WebWorkerMLCEngineHandler {
       // One-time set up that instantiates the chunk generator in worker
       this.handleTask(msg.uuid, async () => {
        const params = msg.content as ChatCompletionStreamInitParams;
+        // Also ensures params.selectedModelId will match what this.engine selects
        await this.reloadIfUnmatched(params.modelId, params.chatOpts);
-        this.asyncGenerate = (await this.engine.chatCompletion(
+        // Register a new async generator for this new request of the model
+        const curGenerator = (await this.engine.chatCompletion(
          params.request,
        )) as AsyncGenerator<ChatCompletionChunk>;
+        this.loadedModelIdToAsyncGenerator.set(
+          params.selectedModelId,
+          curGenerator,
+        );
        onComplete?.(null);
        return null;
      });
@@ -203,10 +214,16 @@ export class WebWorkerMLCEngineHandler {
      // One-time set up that instantiates the chunk generator in worker
      this.handleTask(msg.uuid, async () => {
        const params = msg.content as CompletionStreamInitParams;
+        // Also ensures params.selectedModelId will match what this.engine selects
        await this.reloadIfUnmatched(params.modelId, params.chatOpts);
-        this.asyncGenerate = (await this.engine.completion(
+        // Register a new async generator for this new request of the model
+        const curGenerator = (await this.engine.completion(
          params.request,
        )) as AsyncGenerator<Completion>;
+        this.loadedModelIdToAsyncGenerator.set(
+          params.selectedModelId,
+          curGenerator,
+        );
        onComplete?.(null);
        return null;
      });
@@ -217,13 +234,17 @@ export class WebWorkerMLCEngineHandler {
      // Note: ChatCompletion and Completion share the same chunk generator.
      // For any subsequent request, we return whatever `next()` yields
      this.handleTask(msg.uuid, async () => {
-        if (this.asyncGenerate === undefined) {
+        const params = msg.content as CompletionStreamNextChunkParams;
+        const curGenerator = this.loadedModelIdToAsyncGenerator.get(
+          params.selectedModelId,
+        );
+        if (curGenerator === undefined) {
          throw Error(
-            "Chunk generator in worker should be instantiated by now.",
+            "InternalError: Chunk generator in worker should be instantiated by now.",
          );
        }
        // Yield the next chunk
-        const { value } = await this.asyncGenerate.next();
+        const { value } = await curGenerator.next();
        onComplete?.(value);
        return value;
      });
@@ -264,6 +285,10 @@ export class WebWorkerMLCEngineHandler {
        await this.engine.unload();
        this.modelId = undefined;
        this.chatOpts = undefined;
+        // This may not be cleaned up properly when one asyncGenerator finishes.
+        // We only clear it at unload(), which may not be called upon reload().
+        // However, service_worker may skip reload(). Will leave as is for now.
+        this.loadedModelIdToAsyncGenerator.clear();
        onComplete?.(null);
        return null;
      });
@@ -334,6 +359,7 @@ export class WebWorkerMLCEngineHandler {
 
   /** Check whether frontend expectation matches with backend (modelId and chatOpts). If not (due
    * to possibly killed service worker), we reload here.
+   * For more, see https://github.com/mlc-ai/web-llm/pull/533
    */
   async reloadIfUnmatched(
     expectedModelId: string[],
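Reviewer note: the handler change above boils down to keying chunk generators by model id instead of holding a single one. A self-contained sketch of that pattern in plain TypeScript (the names below are generic and are not web-llm APIs):

```ts
// Sketch: routing next() calls to per-key async generators, as the handler now does
// with loadedModelIdToAsyncGenerator.
const generators = new Map<string, AsyncGenerator<string, void, void>>();

async function* countChunks(tag: string): AsyncGenerator<string, void, void> {
  for (let i = 0; i < 3; i++) {
    yield `${tag}-chunk-${i}`;
  }
}

function streamInit(key: string) {
  // One-time setup, analogous to handling "chatCompletionStreamInit"
  generators.set(key, countChunks(key));
}

async function streamNextChunk(key: string): Promise<string | void> {
  // Analogous to "completionStreamNextChunk": look up the right generator first
  const gen = generators.get(key);
  if (gen === undefined) {
    throw Error("InternalError: generator should be instantiated by now.");
  }
  const { value } = await gen.next();
  return value; // undefined once the generator is exhausted
}

// Two concurrent streams no longer clobber each other:
streamInit("model-A");
streamInit("model-B");
streamNextChunk("model-A").then(console.log); // model-A-chunk-0
streamNextChunk("model-B").then(console.log); // model-B-chunk-0
```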
@@ -613,19 +639,22 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
    * the worker which we yield. The last message is `void`, meaning the generator has nothing
    * to yield anymore.
    *
+   * @param selectedModelId: The model whose async generator we call next() on to get the next chunk.
+   * Needed because an engine can load multiple models.
+   *
    * @note ChatCompletion and Completion share the same chunk generator.
    */
-  async *asyncGenerate(): AsyncGenerator<
-    ChatCompletionChunk | Completion,
-    void,
-    void
-  > {
+  async *asyncGenerate(
+    selectedModelId: string,
+  ): AsyncGenerator<ChatCompletionChunk | Completion, void, void> {
     // Every time it gets called, sends message to worker, asking for the next chunk
     while (true) {
       const msg: WorkerRequest = {
         kind: "completionStreamNextChunk",
         uuid: crypto.randomUUID(),
-        content: null,
+        content: {
+          selectedModelId: selectedModelId,
+        } as CompletionStreamNextChunkParams,
       };
       const ret = await this.getPromise(msg);
       // If the worker's generator reached the end, it would return a `void`
@@ -651,6 +680,14 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
     if (this.modelId === undefined) {
       throw new WorkerEngineModelNotLoadedError(this.constructor.name);
     }
+    // Needed for the streaming case. Consolidate the model id to specify which
+    // model's asyncGenerator to instantiate or call next() on, since the handler
+    // can maintain multiple generators concurrently.
+    const selectedModelId = getModelIdToUse(
+      this.modelId ? this.modelId : [],
+      request.model,
+      "ChatCompletionRequest",
+    );
 
     if (request.stream) {
       // First let worker instantiate a generator
@@ -659,6 +696,7 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
         uuid: crypto.randomUUID(),
         content: {
           request: request,
+          selectedModelId: selectedModelId,
           modelId: this.modelId,
           chatOpts: this.chatOpts,
         },
@@ -666,7 +704,7 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
       await this.getPromise(msg);
 
       // Then return an async chunk generator that resides on the client side
-      return this.asyncGenerate() as AsyncGenerator<
+      return this.asyncGenerate(selectedModelId) as AsyncGenerator<
         ChatCompletionChunk,
         void,
         void
@@ -701,6 +739,15 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
     if (this.modelId === undefined) {
       throw new WorkerEngineModelNotLoadedError(this.constructor.name);
     }
+    // Needed for the streaming case. Consolidate the model id to specify which
+    // model's asyncGenerator to instantiate or call next() on, since the handler
+    // can maintain multiple generators concurrently.
+    const selectedModelId = getModelIdToUse(
+      this.modelId ? this.modelId : [],
+      request.model,
+      "CompletionCreateParams",
+    );
+
     if (request.stream) {
       // First let worker instantiate a generator
       const msg: WorkerRequest = {
@@ -708,6 +755,7 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
         uuid: crypto.randomUUID(),
         content: {
           request: request,
+          selectedModelId: selectedModelId,
           modelId: this.modelId,
           chatOpts: this.chatOpts,
         },
@@ -715,7 +763,11 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
       await this.getPromise(msg);
 
       // Then return an async chunk generator that resides on the client side
-      return this.asyncGenerate() as AsyncGenerator<Completion, void, void>;
+      return this.asyncGenerate(selectedModelId) as AsyncGenerator<
+        Completion,
+        void,
+        void
+      >;
     }
 
     // Non streaming case is more straightforward
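Reviewer note: both `chatCompletion` and `completion` now resolve the target model with `getModelIdToUse`, which lives in src/support.ts and is not shown in this diff. The sketch below only illustrates the selection rule implied by the comments (an explicit `request.model` wins, a single loaded model is the default, anything else is ambiguous); it is not the library implementation, and the names are illustrative.

```ts
// Sketch only: the selection rule described by the comments above, not the actual
// getModelIdToUse from src/support.ts.
function pickModelId(
  loadedModelIds: string[],
  requestModel: string | null | undefined,
  requestName: string,
): string {
  if (requestModel) {
    // An explicit request.model must be one of the loaded models.
    if (!loadedModelIds.includes(requestModel)) {
      throw Error(`${requestName} asked for ${requestModel}, which is not loaded.`);
    }
    return requestModel;
  }
  if (loadedModelIds.length === 1) {
    // Only one model loaded: no ambiguity, use it.
    return loadedModelIds[0];
  }
  // Multiple models loaded and no request.model: ambiguous, so throw,
  // matching the "error will throw due to ambiguity" comments in the examples.
  throw Error(`${requestName} needs an explicit model when multiple models are loaded.`);
}

// e.g. pickModelId(["model-A", "model-B"], "model-B", "ChatCompletionRequest") === "model-B"
```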