diff --git a/examples/README.md b/examples/README.md
index c36589f7..ca27bebf 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -25,6 +25,7 @@ Note that all examples below run in-browser and use WebGPU as a backend.
 - [multi-round-chat](multi-round-chat): while APIs are functional, we internally optimize so that multi round chat usage can reuse KV cache
 - [text-completion](text-completion): demonstrates API `engine.completions.create()`, which is pure text completion with no conversation, as opposed to `engine.chat.completions.create()`
 - [embeddings](embeddings): demonstrates API `engine.embeddings.create()`, and integration with `EmbeddingsInterface` and `MemoryVectorStore` of [Langchain.js](js.langchain.com)
+- [multi-models](multi-models): demonstrates loading multiple models in a single engine concurrently
 #### Advanced OpenAI API Capabilities
diff --git a/examples/multi-models/README.md b/examples/multi-models/README.md
new file mode 100644
index 00000000..7450aad8
--- /dev/null
+++ b/examples/multi-models/README.md
@@ -0,0 +1,14 @@
+# WebLLM Multi-Model Example
+
+This folder provides a minimal demo of loading multiple models in a single engine and
+querying them concurrently with the WebLLM API in a web app setting.
+To try it out, run the following commands in this folder:
+
+```bash
+npm install
+npm start
+```
+
+If you would like to hack on the WebLLM core package, change the web-llm dependency to
+`"file:../.."` and follow the project's build-from-source instructions to build WebLLM
+locally. This option is only recommended if you intend to modify the core package.
diff --git a/examples/multi-models/package.json b/examples/multi-models/package.json
new file mode 100644
index 00000000..5d7fa7c3
--- /dev/null
+++ b/examples/multi-models/package.json
@@ -0,0 +1,20 @@
+{
+  "name": "get-started",
+  "version": "0.1.0",
+  "private": true,
+  "scripts": {
+    "start": "parcel src/multi_models.html --port 8888",
+    "build": "parcel build src/multi_models.html --dist-dir lib"
+  },
+  "devDependencies": {
+    "buffer": "^5.7.1",
+    "parcel": "^2.8.3",
+    "process": "^0.11.10",
+    "tslib": "^2.3.1",
+    "typescript": "^4.9.5",
+    "url": "^0.11.3"
+  },
+  "dependencies": {
+    "@mlc-ai/web-llm": "file:../.."
+  }
+}
diff --git a/examples/multi-models/src/multi_models.html b/examples/multi-models/src/multi_models.html
new file mode 100644
index 00000000..1de9c00b
--- /dev/null
+++ b/examples/multi-models/src/multi_models.html
@@ -0,0 +1,23 @@
+<!DOCTYPE html>
+<html>
+  <head></head>
+  <body>
+    <h2>WebLLM Test Page</h2>
+    Open console to see output
+    <br />
+    <br />
+    <label id="init-label"> </label>
+
+    <h3>Prompt</h3>
+    <label id="prompt-label"> </label>
+
+    <h3>Response</h3>
+    <label id="generate-label"> </label>
+    <br />
+    <label id="stats-label"> </label>
+
+ + + + + diff --git a/examples/multi-models/src/multi_models.ts b/examples/multi-models/src/multi_models.ts new file mode 100644 index 00000000..afafe684 --- /dev/null +++ b/examples/multi-models/src/multi_models.ts @@ -0,0 +1,76 @@ +import * as webllm from "@mlc-ai/web-llm"; + +function setLabel(id: string, text: string) { + const label = document.getElementById(id); + if (label == null) { + throw Error("Cannot find label " + id); + } + label.innerText = text; +} + +/** + * Chat completion (OpenAI style) with streaming, with two models in the pipeline. + */ +async function mainStreaming() { + const initProgressCallback = (report: webllm.InitProgressReport) => { + setLabel("init-label", report.text); + }; + const selectedModel1 = "Phi-3-mini-4k-instruct-q4f32_1-MLC-1k"; + const selectedModel2 = "gemma-2-2b-it-q4f32_1-MLC-1k"; + + const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine( + [selectedModel1, selectedModel2], + { initProgressCallback: initProgressCallback }, + ); + + const request1: webllm.ChatCompletionRequest = { + stream: true, + stream_options: { include_usage: true }, + messages: [ + { role: "user", content: "Provide me three US states." }, + { role: "assistant", content: "California, New York, Pennsylvania." }, + { role: "user", content: "Two more please!" }, + ], + model: selectedModel1, // without specifying it, error will throw due to ambiguity + }; + + const request2: webllm.ChatCompletionRequest = { + stream: true, + stream_options: { include_usage: true }, + messages: [ + { role: "user", content: "Provide me three cities in NY." }, + { role: "assistant", content: "New York, Binghamton, Buffalo." }, + { role: "user", content: "Two more please!" }, + ], + model: selectedModel2, // without specifying it, error will throw due to ambiguity + }; + + const asyncChunkGenerator1 = await engine.chat.completions.create(request1); + let message = ""; + for await (const chunk of asyncChunkGenerator1) { + console.log(chunk); + message += chunk.choices[0]?.delta?.content || ""; + setLabel("generate-label", message); + if (chunk.usage) { + console.log(chunk.usage); // only last chunk has usage + } + // engine.interruptGenerate(); // works with interrupt as well + } + const asyncChunkGenerator2 = await engine.chat.completions.create(request2); + message += "\n\n"; + for await (const chunk of asyncChunkGenerator2) { + console.log(chunk); + message += chunk.choices[0]?.delta?.content || ""; + setLabel("generate-label", message); + if (chunk.usage) { + console.log(chunk.usage); // only last chunk has usage + } + // engine.interruptGenerate(); // works with interrupt as well + } + + // without specifying from which model to get message, error will throw due to ambiguity + console.log("Final message 1:\n", await engine.getMessage(selectedModel1)); + console.log("Final message 2:\n", await engine.getMessage(selectedModel2)); +} + +mainStreaming(); diff --git a/src/embedding.ts b/src/embedding.ts index c620fb26..d5985ee9 100644 --- a/src/embedding.ts +++ b/src/embedding.ts @@ -265,6 +265,10 @@ export class EmbeddingPipeline { await this.device.sync(); } + async asyncLoadWebGPUPipelines() { + await this.tvm.asyncLoadWebGPUPipelines(this.vm.getInternalModule()); + } + // Performance APIs below /** diff --git a/src/engine.ts b/src/engine.ts index 5cb92f3f..3fe0cbb2 100644 --- a/src/engine.ts +++ b/src/engine.ts @@ -18,7 +18,6 @@ import { ChatCompletionRequest, ChatCompletion, ChatCompletionChunk, - ChatCompletionFinishReason, ChatCompletionMessageParam, 
ChatCompletionRequestNonStreaming, ChatCompletionRequestStreaming, @@ -51,18 +50,22 @@ import { cleanModelUrl, findModelRecord, + getModelIdToUse, getToolCallFromOutputMessage, } from "./support"; import { - EngineNotLoadedError, ConfigurationNotInitializedError, DeviceLostError, EmbeddingUnsupportedModelError, FeatureSupportError, MissingModelWasmError, - ModelNotLoadedError, ShaderF16SupportError, WebGPUNotAvailableError, + ReloadArgumentSizeUnmatchedError, + IncorrectPipelineLoadedError, + ReloadModelIdNotUniqueError, + SpecifiedModelNotFoundError, + ModelNotLoadedError, } from "./error"; import { asyncLoadTokenizer } from "./cache_util"; import { EmbeddingPipeline } from "./embedding"; @@ -72,18 +75,20 @@ import { EmbeddingPipeline } from "./embedding"; * * Equivalent to `new webllm.MLCEngine().reload(...)`. * - * @param modelId The model to load, needs to either be in `webllm.prebuiltAppConfig`, or in - * `engineConfig.appConfig`. + * @param modelId model_id of the model to load, either a string or string[]. When multiple models + * are provided, we load all models sequentially. Each modelId needs to either be in + * `webllm.prebuiltAppConfig`, or in `engineConfig.appConfig`. * @param engineConfig Optionally configures the engine, see `webllm.MLCEngineConfig`. - * @param chatOpts Extra options to override chat behavior specified in `mlc-chat-config.json`. + * @param chatOpts Extra options to optionally override the `mlc-chat-config.json` of each `modelId`. + * If specified, its size needs to match that of `modelId`; chatOpts[i] will be used for modelId[i]. * @returns An initialized `WebLLM.MLCEngine` with `modelId` loaded. * @throws Throws error when device lost (mostly due to OOM); users should re-call `CreateMLCEngine()`, * potentially with a smaller model or smaller context window size.
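A minimal usage sketch of the `modelId`/`chatOpts` pairing documented above. The model IDs follow the `multi-models` example earlier in this diff; the `temperature` overrides are purely illustrative assumptions.

```typescript
import * as webllm from "@mlc-ai/web-llm";

async function main() {
  // Load two models into one engine; chatOpts[i] is applied to modelId[i].
  const engine = await webllm.CreateMLCEngine(
    ["Phi-3-mini-4k-instruct-q4f32_1-MLC-1k", "gemma-2-2b-it-q4f32_1-MLC-1k"],
    { initProgressCallback: (report) => console.log(report.text) },
    [{ temperature: 0.7 }, { temperature: 1.0 }], // one ChatOptions per model
  );

  // With more than one model loaded, every request must name its model.
  const reply = await engine.chat.completions.create({
    messages: [{ role: "user", content: "Hello!" }],
    model: "gemma-2-2b-it-q4f32_1-MLC-1k",
  });
  console.log(reply.choices[0].message.content);
}

main();
```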
*/ export async function CreateMLCEngine( - modelId: string, + modelId: string | string[], engineConfig?: MLCEngineConfig, - chatOpts?: ChatOptions, + chatOpts?: ChatOptions | ChatOptions[], ): Promise { const engine = new MLCEngine(engineConfig); await engine.reload(modelId, chatOpts); @@ -104,20 +109,27 @@ export class MLCEngine implements MLCEngineInterface { /** For embeddings.create() */ public embeddings: API.Embeddings; - private currentModelId?: string = undefined; // Model current loaded, undefined if nothing is loaded + /** Maps each loaded model's modelId to its pipeline */ + private loadedModelIdToPipeline: Map< + string, + LLMChatPipeline | EmbeddingPipeline + >; + /** Maps each loaded model's modelId to its chatConfig */ + private loadedModelIdToChatConfig: Map; private logger: (msg: string) => void = log.info; private logitProcessorRegistry?: Map; - private logitProcessor?: LogitProcessor; - private pipeline?: LLMChatPipeline; - private embeddingPipeline?: EmbeddingPipeline; private initProgressCallback?: InitProgressCallback; private interruptSignal = false; private deviceLostIsError = true; // whether device.lost is due to actual error or model reload private reloadController: AbortController | undefined; - private config?: ChatConfig; private appConfig: AppConfig; constructor(engineConfig?: MLCEngineConfig) { + this.loadedModelIdToPipeline = new Map< + string, + LLMChatPipeline | EmbeddingPipeline + >(); + this.loadedModelIdToChatConfig = new Map(); this.appConfig = engineConfig?.appConfig || prebuiltAppConfig; this.setLogLevel(engineConfig?.logLevel || DefaultLogLevel); this.setInitProgressCallback(engineConfig?.initProgressCallback); @@ -150,24 +162,53 @@ export class MLCEngine implements MLCEngineInterface { this.logitProcessorRegistry = logitProcessorRegistry; } + /** + * Set MLCEngine logging output level + * + * @param logLevel The new log level + */ + setLogLevel(logLevel: LogLevel) { + log.setLevel(logLevel); + } + //---------------------------------------- // 1. Model/pipeline loading and unloading //---------------------------------------- - /** - * Reload model `modelId`. - * @param modelId The model to load, needs to either be in `webllm.prebuiltAppConfig`, or in - * `engineConfig.appConfig`. - * @param chatOpts To optionally override the `mlc-chat-config.json` of `modelId`. - * @throws Throws error when device lost (mostly due to OOM); users should re-call reload(), - * potentially with a smaller model or smaller context window size. - */ - async reload(modelId: string, chatOpts?: ChatOptions): Promise { + async reload( + modelId: string | string[], + chatOpts?: ChatOptions | ChatOptions[], + ): Promise { + // 0. Unload all loaded models await this.unload(); + // 1. Convert inputs to arrays + if (!Array.isArray(modelId)) { + modelId = [modelId]; + } + if (chatOpts !== undefined && !Array.isArray(chatOpts)) { + chatOpts = [chatOpts]; + } + // 2. Check whether size matches + if (chatOpts !== undefined && modelId.length !== chatOpts.length) { + throw new ReloadArgumentSizeUnmatchedError( + modelId.length, + chatOpts.length, + ); + } + // 3. Make sure each model in modelId is unique + if (new Set(modelId).size < modelId.length) { + throw new ReloadModelIdNotUniqueError(modelId); + } + // 4. 
Sequentially load each model + // Single abort should stop all to-be-loaded models this.reloadController = new AbortController(); - try { - await this.reloadInternal(modelId, chatOpts); + for (let i = 0; i < modelId.length; i++) { + await this.reloadInternal( + modelId[i], + chatOpts ? chatOpts[i] : undefined, + ); + } } catch (error) { if (error instanceof DOMException && error.name === "AbortError") { log.warn("Reload() is aborted.", error.message); @@ -183,7 +224,7 @@ export class MLCEngine implements MLCEngineInterface { modelId: string, chatOpts?: ChatOptions, ): Promise { - this.logitProcessor = this.logitProcessorRegistry?.get(modelId); + const logitProcessor = this.logitProcessorRegistry?.get(modelId); const tstart = performance.now(); const modelRecord = findModelRecord(modelId, this.appConfig); @@ -205,7 +246,7 @@ export class MLCEngine implements MLCEngineInterface { // load config const configUrl = new URL("mlc-chat-config.json", modelUrl).href; - this.config = { + const curModelConfig = { ...(await configCache.fetchWithCache( configUrl, "json", @@ -214,6 +255,7 @@ export class MLCEngine implements MLCEngineInterface { ...modelRecord.overrides, ...chatOpts, } as ChatConfig; + this.loadedModelIdToChatConfig.set(modelId, curModelConfig); // load tvm wasm let wasmCache: tvmjs.ArtifactCacheTemplate; @@ -297,7 +339,7 @@ export class MLCEngine implements MLCEngineInterface { const tokenizer = await asyncLoadTokenizer( modelUrl, - this.config, + curModelConfig, this.appConfig, this.logger, ); @@ -309,23 +351,26 @@ export class MLCEngine implements MLCEngineInterface { cacheType, this.reloadController?.signal, ); + + // Instantiate pipeline + // TODO: would be good to somehow check for error when LLMChatPipeline is loaded for an + // embedding model, and prompt user to use ModelRecord.model_type + let newPipeline: LLMChatPipeline | EmbeddingPipeline; if (modelRecord.model_type === ModelType.embedding) { - this.embeddingPipeline = new EmbeddingPipeline( - tvm, - tokenizer, - this.config, - ); + newPipeline = new EmbeddingPipeline(tvm, tokenizer, curModelConfig); } else { - this.pipeline = new LLMChatPipeline( + newPipeline = new LLMChatPipeline( tvm, tokenizer, - this.config, - this.logitProcessor, + curModelConfig, + logitProcessor, ); } - await this.pipeline?.asyncLoadWebGPUPipelines(); - const tend = performance.now(); + await newPipeline.asyncLoadWebGPUPipelines(); + this.loadedModelIdToPipeline.set(modelId, newPipeline); + // Clean up + const tend = performance.now(); if (this.initProgressCallback !== undefined) { const text = "Finish loading on " + gpuLabel; this.initProgressCallback({ @@ -334,28 +379,23 @@ export class MLCEngine implements MLCEngineInterface { text: text, }); } - this.currentModelId = modelId; - if (deviceLostInReload) { throw new DeviceLostError(); } } - /** - * Unloads the currently loaded model and destroy the webgpu device. Waits - * until the webgpu device finishes all submitted work and destroys itself. - * @note This is an asynchronous function. - */ async unload() { this.deviceLostIsError = false; // so that unload() does not trigger device.lost error - this.pipeline?.dispose(); - this.embeddingPipeline?.dispose(); - // Wait until device is actually destroyed so we can safely set deviceLostIsError back to true - await this.pipeline?.sync(); - await this.embeddingPipeline?.sync(); - this.pipeline = undefined; - this.embeddingPipeline = undefined; - this.currentModelId = undefined; + // TODO: can optimize by calling dispose() to all pipelines in parallel. 
However, need to wait + // for all sync() to finish before proceeding (e.g. naive forEach does not work) + for (const entry of Array.from(this.loadedModelIdToPipeline.entries())) { + const pipeline = entry[1]; + pipeline.dispose(); + // Wait until device is actually destroyed so we can safely set deviceLostIsError back to true + await pipeline.sync(); + } + this.loadedModelIdToPipeline.clear(); + this.loadedModelIdToChatConfig.clear(); this.deviceLostIsError = true; if (this.reloadController) { this.reloadController.abort("Engine.unload() is called."); @@ -369,44 +409,52 @@ export class MLCEngine implements MLCEngineInterface { private async _generate( input: - | string | ChatCompletionRequestNonStreaming | CompletionCreateParamsNonStreaming, - genConfig?: GenerationConfig, + pipeline: LLMChatPipeline, + chatConfig: ChatConfig, + genConfig: GenerationConfig, ): Promise { this.interruptSignal = false; if (genConfig !== undefined) { postInitAndCheckGenerationConfigValues(genConfig); } - await this.prefill(input, genConfig); + await this.prefill(input, pipeline, chatConfig, genConfig); let counter = 1; - while (!this.stopped()) { + while (!pipeline.stopped()) { if (this.interruptSignal) { - this.getPipeline().triggerStop(); + pipeline.triggerStop(); break; } counter += 1; - await this.decode(genConfig); + await this.decode(pipeline, genConfig); } - return await this.getMessage(); + return pipeline.getMessage(); } /** * Similar to `_generate()`; but instead of using callback, we use an async iterable. - * @param request Request for chat completion. - * @param genConfig Generation config extraced from `request`. */ asyncGenerate( request: ChatCompletionRequestStreaming, + model: string, + pipeline: LLMChatPipeline, + chatConfig: ChatConfig, genConfig: GenerationConfig, ): AsyncGenerator; asyncGenerate( request: CompletionCreateParamsStreaming, + model: string, + pipeline: LLMChatPipeline, + chatConfig: ChatConfig, genConfig: GenerationConfig, ): AsyncGenerator; async *asyncGenerate( request: ChatCompletionRequestStreaming | CompletionCreateParamsStreaming, + model: string, + pipeline: LLMChatPipeline, + chatConfig: ChatConfig, genConfig: GenerationConfig, ): AsyncGenerator { // 0. Pre-processing @@ -422,12 +470,11 @@ export class MLCEngine implements MLCEngineInterface { } postInitAndCheckGenerationConfigValues(genConfig); if (request.seed !== null && request.seed !== undefined) { - this.getPipeline().setSeed(request.seed); + pipeline.setSeed(request.seed); } // 1. Helper function that generates the chunk // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - const model = this.currentModelId!; const created = Date.now(); const id = crypto.randomUUID(); this.interruptSignal = false; @@ -446,13 +493,13 @@ export class MLCEngine implements MLCEngineInterface { } async function _getChunk( - thisModule: MLCEngine, + selectedPipeline: LLMChatPipeline, ): Promise { // Remove the replacement character (U+FFFD) from the response to handle emojis. 
// Each emoji is made up of multiples of 4 tokens; when truncated, it is displayed as �, so // we skip this delta until a full emoji is rendered // TODO(Charlie): This does not consider cases of � not being emoji, need to fix with Streamer - const curMessage = await thisModule.getMessage(); + const curMessage = selectedPipeline.getMessage(); const numTrailingReplacementChar = _countTrailingReplacementChar(curMessage); if (numTrailingReplacementChar % 4 !== 0) { @@ -463,7 +510,7 @@ export class MLCEngine implements MLCEngineInterface { prevMessageLength = curMessage.length; const logprobs = request.logprobs ? ({ - content: thisModule.getPipeline().getTokenLogprobArray().slice(-1), // always the last entry + content: selectedPipeline.getTokenLogprobArray().slice(-1), // always the last entry } as ChatCompletionChunk.Choice.Logprobs) : null; if (isChatCompletion) { @@ -502,19 +549,19 @@ export class MLCEngine implements MLCEngineInterface { } // 2. Auto-regressive loop - await this.prefill(request, genConfig); - let curChunk = await _getChunk(this); // prefill produces a chunk + await this.prefill(request, pipeline, chatConfig, genConfig); + let curChunk = await _getChunk(pipeline); // prefill produces a chunk if (curChunk) { yield curChunk; } - while (!this.stopped()) { + while (!pipeline.stopped()) { if (this.interruptSignal) { - this.getPipeline().triggerStop(); + pipeline.triggerStop(); break; } - await this.decode(genConfig); - curChunk = await _getChunk(this); + await this.decode(pipeline, genConfig); + curChunk = await _getChunk(pipeline); if (curChunk) { yield curChunk; } @@ -522,20 +569,20 @@ export class MLCEngine implements MLCEngineInterface { // Reset seed -- we do not want this seed to affect future requests if (request.seed !== null && request.seed !== undefined) { - this.getPipeline().setSeed(Date.now()); + pipeline.setSeed(Date.now()); } // 3. Last chunk empty marking the end // If function calling, use the last chunk to return tool_calls - let finish_reason = this.getFinishReason()!; + let finish_reason = pipeline.getFinishReason()!; let tool_calls: | Array | undefined; - if (this.getFinishReason()! == "stop" && isFunctionCalling) { + if (pipeline.getFinishReason() === "stop" && isFunctionCalling) { // If stopped due to length or abort, cannot output return tool_calls field finish_reason = "tool_calls"; - const outputMessage = await this.getMessage(); + const outputMessage = pipeline.getMessage(); tool_calls = getToolCallFromOutputMessage( outputMessage, /*isStreaming=*/ true, @@ -581,13 +628,10 @@ export class MLCEngine implements MLCEngineInterface { // 4. 
Usage chunk if (request.stream_options?.include_usage) { - const completion_tokens = - this.getPipeline().getCurRoundDecodingTotalTokens(); - const prompt_tokens = this.getPipeline().getCurRoundPrefillTotalTokens(); - const prefill_tokens_per_s = - this.getPipeline().getCurRoundPrefillTokensPerSec(); - const decode_tokens_per_s = - this.getPipeline().getCurRoundDecodingTokensPerSec(); + const completion_tokens = pipeline.getCurRoundDecodingTotalTokens(); + const prompt_tokens = pipeline.getCurRoundPrefillTotalTokens(); + const prefill_tokens_per_s = pipeline.getCurRoundPrefillTokensPerSec(); + const decode_tokens_per_s = pipeline.getCurRoundDecodingTokensPerSec(); const usage: CompletionUsage = { completion_tokens: completion_tokens, prompt_tokens: prompt_tokens, @@ -649,11 +693,10 @@ export class MLCEngine implements MLCEngineInterface { async chatCompletion( request: ChatCompletionRequest, ): Promise | ChatCompletion> { - // 0. Preprocess inputs - if (!this.currentModelId) { - throw new ModelNotLoadedError(); - } - API.postInitAndCheckFieldsChatCompletion(request, this.currentModelId); + // 0. Check model loaded and preprocess inputs + const [selectedModelId, selectedPipeline, selectedChatConfig] = + this.getLLMStates("ChatCompletionRequest", request.model); + API.postInitAndCheckFieldsChatCompletion(request, selectedModelId); const genConfig: GenerationConfig = { frequency_penalty: request.frequency_penalty, presence_penalty: request.presence_penalty, @@ -669,11 +712,17 @@ export class MLCEngine implements MLCEngineInterface { // 1. If request is streaming, return an AsyncIterable (an iterable version of `_generate()`) if (request.stream) { - return this.asyncGenerate(request, genConfig); + return this.asyncGenerate( + request, + selectedModelId, + selectedPipeline, + selectedChatConfig, + genConfig, + ); } if (request.seed !== null && request.seed !== undefined) { - this.getPipeline().setSeed(request.seed); + selectedPipeline.setSeed(request.seed); } // 2. If request is non-streaming, directly reuse `_generate()` @@ -687,18 +736,23 @@ export class MLCEngine implements MLCEngineInterface { let outputMessage: string; if (this.interruptSignal) { // A single interrupt signal should stop all choices' generations - this.getPipeline().triggerStop(); + selectedPipeline.triggerStop(); outputMessage = ""; } else { - outputMessage = await this._generate(request, genConfig); + outputMessage = await this._generate( + request, + selectedPipeline, + selectedChatConfig, + genConfig, + ); } - let finish_reason = this.getFinishReason()!; + let finish_reason = selectedPipeline.getFinishReason()!; // 3. Post processing for function calling const isFunctionCalling = request.tools !== undefined && request.tools !== null; let tool_calls: Array | undefined; - if (this.getFinishReason()! == "stop" && isFunctionCalling) { + if (selectedPipeline.getFinishReason() === "stop" && isFunctionCalling) { // If stopped due to length or abort, cannot output return tool_calls field finish_reason = "tool_calls"; tool_calls = getToolCallFromOutputMessage( @@ -713,7 +767,7 @@ export class MLCEngine implements MLCEngineInterface { index: i, logprobs: request.logprobs ? 
({ - content: this.getPipeline().getTokenLogprobArray(), + content: selectedPipeline.getTokenLogprobArray(), } as ChatCompletion.Choice.Logprobs) : null, message: isFunctionCalling @@ -727,16 +781,16 @@ export class MLCEngine implements MLCEngineInterface { role: "assistant", }, }); - completion_tokens += this.getPipeline().getCurRoundDecodingTotalTokens(); - prompt_tokens += this.getPipeline().getCurRoundPrefillTotalTokens(); - prefill_time += this.getPipeline().getCurRoundPrefillTotalTime(); - decode_time += this.getPipeline().getCurRoundDecodingTotalTime(); + completion_tokens += selectedPipeline.getCurRoundDecodingTotalTokens(); + prompt_tokens += selectedPipeline.getCurRoundPrefillTotalTokens(); + prefill_time += selectedPipeline.getCurRoundPrefillTotalTime(); + decode_time += selectedPipeline.getCurRoundDecodingTotalTime(); } const response: ChatCompletion = { id: crypto.randomUUID(), choices: choices, - model: this.currentModelId, + model: selectedModelId, object: "chat.completion", created: Date.now(), usage: { @@ -752,7 +806,7 @@ export class MLCEngine implements MLCEngineInterface { // Reset seed -- we do not want this seed to affect future requests if (request.seed !== null && request.seed !== undefined) { - this.getPipeline().setSeed(Date.now()); + selectedPipeline.setSeed(Date.now()); } return response; } @@ -777,11 +831,10 @@ export class MLCEngine implements MLCEngineInterface { async completion( request: CompletionCreateParams, ): Promise | Completion> { - // 0. Preprocess inputs - if (!this.currentModelId) { - throw new ModelNotLoadedError(); - } - API.postInitAndCheckFieldsCompletion(request, this.currentModelId); + // 0. Check model loaded and preprocess inputs + const [selectedModelId, selectedPipeline, selectedChatConfig] = + this.getLLMStates("ChatCompletionRequest", request.model); + API.postInitAndCheckFieldsCompletion(request, selectedModelId); const genConfig: GenerationConfig = { frequency_penalty: request.frequency_penalty, presence_penalty: request.presence_penalty, @@ -796,11 +849,17 @@ export class MLCEngine implements MLCEngineInterface { // 1. If request is streaming, return an AsyncIterable (an iterable version of `_generate()`) if (request.stream) { - return this.asyncGenerate(request, genConfig); + return this.asyncGenerate( + request, + selectedModelId, + selectedPipeline, + selectedChatConfig, + genConfig, + ); } if (request.seed !== null && request.seed !== undefined) { - this.getPipeline().setSeed(request.seed); + selectedPipeline.setSeed(request.seed); } // 2. If request is non-streaming, directly reuse `_generate()` @@ -814,12 +873,17 @@ export class MLCEngine implements MLCEngineInterface { let outputMessage: string; if (this.interruptSignal) { // A single interrupt signal should stop all choices' generations - this.getPipeline().triggerStop(); + selectedPipeline.triggerStop(); outputMessage = ""; } else { - outputMessage = await this._generate(request, genConfig); + outputMessage = await this._generate( + request, + selectedPipeline, + selectedChatConfig, + genConfig, + ); } - const finish_reason = this.getFinishReason()!; + const finish_reason = selectedPipeline.getFinishReason()!; choices.push({ // eslint-disable-next-line @typescript-eslint/no-non-null-assertion @@ -827,21 +891,21 @@ export class MLCEngine implements MLCEngineInterface { index: i, logprobs: request.logprobs ? 
({ - content: this.getPipeline().getTokenLogprobArray(), + content: selectedPipeline.getTokenLogprobArray(), } as ChatCompletion.Choice.Logprobs) : null, text: request.echo ? request.prompt + outputMessage : outputMessage, }); - completion_tokens += this.getPipeline().getCurRoundDecodingTotalTokens(); - prompt_tokens += this.getPipeline().getCurRoundPrefillTotalTokens(); - prefill_time += this.getPipeline().getCurRoundPrefillTotalTime(); - decode_time += this.getPipeline().getCurRoundDecodingTotalTime(); + completion_tokens += selectedPipeline.getCurRoundDecodingTotalTokens(); + prompt_tokens += selectedPipeline.getCurRoundPrefillTotalTokens(); + prefill_time += selectedPipeline.getCurRoundPrefillTotalTime(); + decode_time += selectedPipeline.getCurRoundDecodingTotalTime(); } const response: Completion = { id: crypto.randomUUID(), choices: choices, - model: this.currentModelId, + model: selectedModelId, object: "text_completion", created: Date.now(), usage: { @@ -857,7 +921,7 @@ export class MLCEngine implements MLCEngineInterface { // Reset seed -- we do not want this seed to affect future requests if (request.seed !== null && request.seed !== undefined) { - this.getPipeline().setSeed(Date.now()); + selectedPipeline.setSeed(Date.now()); } return response; } @@ -866,20 +930,34 @@ export class MLCEngine implements MLCEngineInterface { request: EmbeddingCreateParams, ): Promise { // 0. Preprocess inputs - if (!this.currentModelId) { - throw new ModelNotLoadedError(); + const loadedModelIds: string[] = Array.from( + this.loadedModelIdToPipeline.keys(), + ); + const selectedModelId: string = getModelIdToUse( + loadedModelIds, + request.model, + "EmbeddingCreateParams", + ); + const selectedPipeline = this.loadedModelIdToPipeline.get(selectedModelId); + if (!(selectedPipeline instanceof EmbeddingPipeline)) { + throw new IncorrectPipelineLoadedError( + selectedModelId, + "EmbeddingPipeline", + "EmbeddingCreateParams", + ); } if ( - findModelRecord(this.currentModelId, this.appConfig).model_type !== + findModelRecord(selectedModelId, this.appConfig).model_type !== ModelType.embedding ) { - throw new EmbeddingUnsupportedModelError(this.currentModelId); + throw new EmbeddingUnsupportedModelError(selectedModelId); } - API.postInitAndCheckFieldsEmbedding(request, this.currentModelId); + API.postInitAndCheckFieldsEmbedding(request, selectedModelId); // 1. Call EmbeddingPipeline to get embeddings - const embedResult: Array> = - await this.getEmbeddingPipeline().embedStep(request.input); + const embedResult: Array> = await selectedPipeline.embedStep( + request.input, + ); // 2. Prepare response const batchSize = embedResult.length; @@ -894,15 +972,13 @@ export class MLCEngine implements MLCEngineInterface { } return { data: data, - model: this.currentModelId, + model: selectedModelId, object: "list", usage: { - prompt_tokens: - this.getEmbeddingPipeline().getCurRoundEmbedTotalTokens(), - total_tokens: this.getEmbeddingPipeline().getCurRoundEmbedTotalTokens(), + prompt_tokens: selectedPipeline.getCurRoundEmbedTotalTokens(), + total_tokens: selectedPipeline.getCurRoundEmbedTotalTokens(), extra: { - prefill_tokens_per_s: - this.getEmbeddingPipeline().getCurRoundEmbedTokensPerSec(), + prefill_tokens_per_s: selectedPipeline.getCurRoundEmbedTokensPerSec(), }, }, }; @@ -950,42 +1026,64 @@ export class MLCEngine implements MLCEngineInterface { return gpuDetectOutput.adapterInfo.vendor; } - //---------------------------------------------- - // 5. 
Low-level APIs that interact with pipeline - //---------------------------------------------- - private getPipeline(): LLMChatPipeline { - if (this.pipeline === undefined) { - throw new EngineNotLoadedError(); - } - return this.pipeline; - } + //--------------------------------------------------------------- + // 5. Helper for querying currently loaded model/pipeline/config. + // Needed due to possibly multiple loaded models. + //--------------------------------------------------------------- - private getEmbeddingPipeline(): EmbeddingPipeline { - if (this.embeddingPipeline === undefined) { - throw new EngineNotLoadedError(); + /** + * Return the model, its LLMChatPipeline, and ChatConfig to use. Throws error when unclear which + * model to load. + * @param requestName The type of request or API to load the model for. Needed for error throwing. + * @param modelId Model the user specified to load via the request. Required when multiple + * models are loaded + */ + private getLLMStates( + requestName: string, + modelId?: string | null, + ): [string, LLMChatPipeline, ChatConfig] { + // TODO(webllm-team): when more modalities/pipelines are supported, make this method + // generic for different pipelines. e.g. currently embedding() does not use this method + const loadedModelIds: string[] = Array.from( + this.loadedModelIdToPipeline.keys(), + ); + const selectedModelId: string = getModelIdToUse( + loadedModelIds, + modelId, + requestName, + ); + const selectedPipeline = this.loadedModelIdToPipeline.get(selectedModelId); + if (!(selectedPipeline instanceof LLMChatPipeline)) { + throw new IncorrectPipelineLoadedError( + selectedModelId, + "LLMChatPipeline", + requestName, + ); } - return this.embeddingPipeline; + const selectedChatConfig = + this.loadedModelIdToChatConfig.get(selectedModelId); + if (selectedChatConfig === undefined) { + throw new Error( + `InternalError: chat config not registered for ${selectedModelId}.`, + ); + } + return [selectedModelId, selectedPipeline, selectedChatConfig]; } + //-------------------------------------------------------------------- + // 6. External low-level APIs that directly interacts with a pipeline. + //-------------------------------------------------------------------- + async forwardTokensAndSample( inputIds: Array, isPrefill: boolean, + modelId?: string, ): Promise { - return this.getPipeline().forwardTokensAndSample(inputIds, isPrefill); - } - - /** - * @returns Whether the generation stopped. - */ - stopped(): boolean { - return this.getPipeline().stopped(); - } - - /** - * @returns Finish reason; undefined if generation not started/stopped yet. - */ - getFinishReason(): ChatCompletionFinishReason | undefined { - return this.getPipeline().getFinishReason(); + const [, selectedPipeline] = this.getLLMStates( + "forwardTokensAndSample", + modelId, + ); + return selectedPipeline.forwardTokensAndSample(inputIds, isPrefill); } /** @@ -993,33 +1091,46 @@ export class MLCEngine implements MLCEngineInterface { * * @returns The current output message. 
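Taken together, the per-model low-level APIs in this hunk are used as in the following sketch. Engine creation is omitted; the two model IDs are assumed to be whatever was passed to `reload()`, as in the multi-models example.

```typescript
import type { MLCEngineInterface } from "@mlc-ai/web-llm";

// With a single model loaded, the modelId argument can be omitted; with several
// models loaded, omitting it throws UnclearModelToUseError.
async function inspect(
  engine: MLCEngineInterface,
  modelA: string,
  modelB: string,
) {
  console.log("Last output of", modelA, ":", await engine.getMessage(modelA));
  console.log("Last output of", modelB, ":", await engine.getMessage(modelB));

  // Other per-pipeline APIs follow the same pattern.
  await engine.resetChat(/* keepStats= */ false, modelA);
  console.log(await engine.runtimeStatsText(modelB));
}
```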
*/ - async getMessage(): Promise { - return this.getPipeline().getMessage(); - } - - /** - * Set MLCEngine logging output level - * - * @param logLevel The new log level - */ - setLogLevel(logLevel: LogLevel) { - log.setLevel(logLevel); + async getMessage(modelId?: string): Promise { + const [, selectedPipeline] = this.getLLMStates("getMessage", modelId); + return selectedPipeline.getMessage(); } - async runtimeStatsText(): Promise { + async runtimeStatsText(modelId?: string): Promise { log.warn( "WARNING: `runtimeStatsText()` will soon be deprecated. " + "Please use `ChatCompletion.usage` for non-streaming requests, or " + "`ChatCompletionChunk.usage` for streaming requests, enabled by `stream_options`. " + "The only flow that expects to use `runtimeStatsText()` as of now is `forwardTokensAndSample()`.", ); - return this.getPipeline().runtimeStatsText(); + const [, selectedPipeline] = this.getLLMStates("runtimeStatsText", modelId); + return selectedPipeline.runtimeStatsText(); } - async resetChat(keepStats = false) { - this.pipeline?.resetChat(keepStats); + async resetChat(keepStats = false, modelId?: string) { + try { + const [, selectedPipeline] = this.getLLMStates("resetChat", modelId); + selectedPipeline.resetChat(keepStats); + } catch (error) { + if ( + error instanceof ModelNotLoadedError || + error instanceof SpecifiedModelNotFoundError + ) { + // Only allow calling resetChat before pipeline instantiated. + log.debug( + "Caught an expected error in resetChat, treating it as no-op. Error: ", + error, + ); + } else { + throw error; + } + } } + //----------------------------------------------- + // 7. Prefill and decode given an LLMChatPipeline + //----------------------------------------------- + /** * Run a prefill step with a given input. * @@ -1031,36 +1142,40 @@ export class MLCEngine implements MLCEngineInterface { * performing multi-round chatting, so we do not reset, hence reusing KV cache. Otherwise, we * reset every thing, treating the request as something completely new. * - * @param input The input prompt, or `messages` in OpenAI-like APIs. + * @param input The OpenAI-style prompt to prefill. + * @param pipeline The loaded pipeline, hence model, to carry out this prefill. + * @param chatConfig The chat config to use for this model. + * @param genConfig Generation config. */ async prefill( - input: string | ChatCompletionRequest | CompletionCreateParams, - genConfig?: GenerationConfig, + input: ChatCompletionRequest | CompletionCreateParams, + pipeline: LLMChatPipeline, + chatConfig: ChatConfig, + genConfig: GenerationConfig, ) { - if (this.config === undefined) { + // TODO: SPECIFY MODEL TO PERFORM PREFILL, HENCE RETRIEVE CONFIG + if (chatConfig === undefined) { throw new ConfigurationNotInitializedError(); } let input_str: string; let input_role_str: string | undefined; let lastMsgRole = Role.user; - if (typeof input === "string") { - input_str = input; - } else if ("messages" in input) { + if ("messages" in input) { // For ChatCompletionRequest, we prepare input using `messages` // 1. Get new conversation based on request, determine if we are in multiround chatting - const oldConv = this.getPipeline().getConversationObject(); + const oldConv = pipeline.getConversationObject(); const newConv = getConversationFromChatCompletionRequest( input, - this.config, + chatConfig, ); if (!compareConversationObject(oldConv, newConv)) { // Not the same conversation, so not multiround chatting, reset everything (KV cache, etc.) 
- this.resetChat(); - this.getPipeline().setConversation(newConv); + pipeline.resetChat(); + pipeline.setConversation(newConv); } else if (newConv.messages.length === 0) { // Empty oldConv, and no chat history in newConv, so reset and setConversation - this.resetChat(); - this.getPipeline().setConversation(newConv); + pipeline.resetChat(); + pipeline.setConversation(newConv); } else { log.info("Multiround chatting, reuse KVCache."); } @@ -1076,15 +1191,15 @@ export class MLCEngine implements MLCEngineInterface { } else { // For CompletionCreateParams, the input is just the prompt input_str = input.prompt; - this.resetChat(); + pipeline.resetChat(); const newConv = getConversation( - this.config.conv_template, - this.config.conv_config, + chatConfig.conv_template, + chatConfig.conv_config, true, ); - this.getPipeline().setConversation(newConv); + pipeline.setConversation(newConv); } - return this.getPipeline().prefillStep( + return pipeline.prefillStep( input_str, lastMsgRole, input_role_str, @@ -1095,7 +1210,7 @@ export class MLCEngine implements MLCEngineInterface { /** * Run a decode step to decode the next token. */ - async decode(genConfig?: GenerationConfig) { - return this.getPipeline().decodeStep(genConfig); + async decode(pipeline: LLMChatPipeline, genConfig?: GenerationConfig) { + return pipeline.decodeStep(genConfig); } } diff --git a/src/error.ts b/src/error.ts index ef672892..0c4bd568 100644 --- a/src/error.ts +++ b/src/error.ts @@ -84,9 +84,11 @@ export class WebGPUNotFoundError extends Error { } export class ModelNotLoadedError extends Error { - constructor() { + constructor(requestName: string) { super( - "Model not loaded before calling chatCompletion(). Please ensure you have called `MLCEngine.reload(model)` to load the model before initiating chat operations, or initialize your engine using `CreateMLCEngine()` with a valid model configuration.", + `Model not loaded before trying to complete ${requestName}. Please ensure you have called ` + + `MLCEngine.reload(model) to load the model before initiating APIs, ` + + `or initialize your engine using CreateMLCEngine() with a valid model configuration.`, ); this.name = "ModelNotLoadedError"; } @@ -479,3 +481,63 @@ export class EmbeddingInputEmptyError extends Error { this.name = "EmbeddingInputEmptyError"; } } + +export class ReloadArgumentSizeUnmatchedError extends Error { + constructor(numModelId: number, numChatOpts: number) { + super( + `Expect chatOpts, if specified, to match the size of modelId. However, got ` + + `${numModelId} modelId, but ${numChatOpts} chatOpts.`, + ); + this.name = "ReloadArgumentSizeUnmatchedError"; + } +} + +export class UnclearModelToUseError extends Error { + constructor(loadedModels: string[], requestName: string) { + super( + `Multiple models are loaded in engine. Please specify the model in ${requestName}.\n` + + `Currently loaded models are:\n${loadedModels}`, + ); + this.name = "UnclearModelToUseError"; + } +} + +export class SpecifiedModelNotFoundError extends Error { + constructor( + loadedModels: string[], + requestedModelId: string, + requestName: string, + ) { + super( + `Specified model ${requestedModelId} for ${requestName} is not found in loaded models. ` + + `Please check if the correct model is loaded/specified. 
` + `Currently loaded models are:\n${loadedModels}`, ); this.name = "SpecifiedModelNotFoundError"; } } + +export class IncorrectPipelineLoadedError extends Error { + constructor( + selectedModelId: string, + expectedPipeline: string, + requestName: string, + ) { + super( + `${requestName} expects the model to be loaded with ${expectedPipeline}. However, ` + + `${selectedModelId} is not loaded with this pipeline.`, + ); + this.name = "IncorrectPipelineLoadedError"; + } +} + +export class ReloadModelIdNotUniqueError extends Error { + constructor(modelId: string[]) { + super( + `Models in the modelId list passed to reload() need to be unique. If you want to ` + + `load copies of the same model, consider making copies of the ModelRecord with ` + + `different model_id. Received modelId: ${modelId}`, + ); + this.name = "ReloadModelIdNotUniqueError"; + } +} diff --git a/src/extension_service_worker.ts b/src/extension_service_worker.ts index d1e42032..97e0dcb0 100644 --- a/src/extension_service_worker.ts +++ b/src/extension_service_worker.ts @@ -7,7 +7,7 @@ import { WebWorkerMLCEngineHandler, WebWorkerMLCEngine, } from "./web_worker"; -import { areChatOptionsEqual } from "./utils"; +import { areArraysEqual, areChatOptionsListEqual } from "./utils"; import { WebGPUNotFoundError } from "./error"; export interface ExtensionMLCEngineConfig extends MLCEngineConfig { @@ -66,8 +66,8 @@ export class ServiceWorkerMLCEngineHandler extends WebWorkerMLCEngineHandler { const params = msg.content as ReloadParams; // If the modelId, chatOpts, and appConfig are the same, immediately return if ( - this.modelId === params.modelId && - areChatOptionsEqual(this.chatOpts, params.chatOpts) + areArraysEqual(this.modelId, params.modelId) && + areChatOptionsListEqual(this.chatOpts, params.chatOpts) ) { log.info("Already loaded the model. Skip loading"); const gpuDetectOutput = await tvmjs.detectGPUDevice(); @@ -104,18 +104,21 @@ export class ServiceWorkerMLCEngineHandler extends WebWorkerMLCEngineHandler { /** * Create a ServiceWorkerMLCEngine. * - * @param modelId The model to load, needs to either be in `webllm.prebuiltAppConfig`, or in - * `engineConfig.appConfig`. + * @param modelId model_id of the model to load, either a string or string[]. When multiple models + * are provided, we load all models sequentially. Each modelId needs to either be in + * `webllm.prebuiltAppConfig`, or in `engineConfig.appConfig`. * @param engineConfig Optionally configures the engine, see `webllm.MLCEngineConfig` for more. + * @param chatOpts Extra options to optionally override the `mlc-chat-config.json` of each `modelId`. + * If specified, its size needs to match that of `modelId`; chatOpts[i] will be used for modelId[i]. * @param keepAliveMs The interval to send keep alive messages to the service worker. * See [Service worker lifecycle](https://developer.chrome.com/docs/extensions/develop/concepts/service-workers/lifecycle#idle-shutdown) * The default is 10s. * @returns An initialized `WebLLM.ServiceWorkerMLCEngine` with `modelId` loaded.
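The same `string[]`/`ChatOptions[]` convention applies to the service-worker flow. A sketch, assuming the top-level `CreateServiceWorkerMLCEngine` export and a hypothetical service-worker script path (`sw.ts`) that runs `ServiceWorkerMLCEngineHandler`:

```typescript
import * as webllm from "@mlc-ai/web-llm";

async function initServiceWorkerEngine() {
  // Register the service worker that hosts ServiceWorkerMLCEngineHandler.
  await navigator.serviceWorker.register(
    new URL("sw.ts", import.meta.url), // hypothetical worker script path
    { type: "module" },
  );

  // Load both models in the service worker; omit chatOpts, or pass one per model.
  return webllm.CreateServiceWorkerMLCEngine(
    ["Phi-3-mini-4k-instruct-q4f32_1-MLC-1k", "gemma-2-2b-it-q4f32_1-MLC-1k"],
    { initProgressCallback: (report) => console.log(report.text) },
  );
}
```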
*/ export async function CreateServiceWorkerMLCEngine( - modelId: string, + modelId: string | string[], engineConfig?: ExtensionMLCEngineConfig, - chatOpts?: ChatOptions, + chatOpts?: ChatOptions | ChatOptions[], keepAliveMs = 10000, ): Promise { const serviceWorkerMLCEngine = new ServiceWorkerMLCEngine( diff --git a/src/message.ts b/src/message.ts index f9cd5775..618dd51a 100644 --- a/src/message.ts +++ b/src/message.ts @@ -1,4 +1,4 @@ -import { AppConfig, ChatOptions, GenerationConfig } from "./config"; +import { AppConfig, ChatOptions } from "./config"; import { InitProgressReport, LogLevel } from "./types"; import { ChatCompletionRequestStreaming, @@ -40,55 +40,55 @@ type RequestKind = type ResponseKind = "return" | "throw" | "initProgressCallback"; export interface ReloadParams { - modelId: string; - chatOpts?: ChatOptions; + modelId: string[]; + chatOpts?: ChatOptions[]; } export interface ResetChatParams { keepStats: boolean; + modelId?: string; +} +export interface GetMessageParams { + modelId?: string; +} +export interface RuntimeStatsTextParams { + modelId?: string; } export interface ForwardTokensAndSampleParams { inputIds: Array; isPrefill: boolean; + modelId?: string; } + +// Notes on the following Params with modelId and chatOpts: +// These fields are the model and chatOpts that the frontend engine expects the backend +// to be loaded with. If not loaded due to web/service worker unexpectedly killed, +// handler will call reload(). An engine can load multiple models, hence both are list. +// TODO(webllm-team): should add appConfig here as well if rigorous. +// Fore more, see https://github.com/mlc-ai/web-llm/pull/471 export interface ChatCompletionNonStreamingParams { request: ChatCompletionRequestNonStreaming; - // The model and chatOpts that the frontend engine expects the backend to be loaded with. - // If not loaded due to service worker unexpectedly killed, handler will call reload(). - // TODO(webllm-team): should add appConfig here as well. - modelId: string; - chatOpts: ChatOptions; + modelId: string[]; + chatOpts?: ChatOptions[]; } export interface ChatCompletionStreamInitParams { request: ChatCompletionRequestStreaming; - // The model and chatOpts that the frontend engine expects the backend to be loaded with. - // If not loaded due to service worker unexpectedly killed, handler will call reload(). - // TODO(webllm-team): should add appConfig here as well. - modelId: string; - chatOpts: ChatOptions; + modelId: string[]; + chatOpts?: ChatOptions[]; } export interface CompletionNonStreamingParams { request: CompletionCreateParamsNonStreaming; - // The model and chatOpts that the frontend engine expects the backend to be loaded with. - // If not loaded due to service worker unexpectedly killed, handler will call reload(). - // TODO(webllm-team): should add appConfig here as well. - modelId: string; - chatOpts: ChatOptions; + modelId: string[]; + chatOpts?: ChatOptions[]; } export interface CompletionStreamInitParams { request: CompletionCreateParamsStreaming; - // The model and chatOpts that the frontend engine expects the backend to be loaded with. - // If not loaded due to service worker unexpectedly killed, handler will call reload(). - // TODO(webllm-team): should add appConfig here as well. - modelId: string; - chatOpts: ChatOptions; + modelId: string[]; + chatOpts?: ChatOptions[]; } export interface EmbeddingParams { request: EmbeddingCreateParams; - // The model and chatOpts that the frontend engine expects the backend to be loaded with. 
- // If not loaded due to service worker unexpectedly killed, handler will call reload(). - // TODO(webllm-team): should add appConfig here as well. - modelId: string; - chatOpts: ChatOptions; + modelId: string[]; + chatOpts?: ChatOptions[]; } export interface CustomRequestParams { @@ -98,6 +98,8 @@ export interface CustomRequestParams { export type MessageContent = | ReloadParams | ResetChatParams + | GetMessageParams + | RuntimeStatsTextParams | ForwardTokensAndSampleParams | ChatCompletionNonStreamingParams | ChatCompletionStreamInitParams diff --git a/src/openai_api_protocols/chat_completion.ts b/src/openai_api_protocols/chat_completion.ts index ff77015d..492b869f 100644 --- a/src/openai_api_protocols/chat_completion.ts +++ b/src/openai_api_protocols/chat_completion.ts @@ -233,12 +233,13 @@ export interface ChatCompletionRequestBase { */ response_format?: ResponseFormat; - //////////////// BELOW FIELDS NOT SUPPORTED YET //////////////// - /** - * Model to carry out this API. + * ID of the model to use. This equals to `ModelRecord.model_id`, which needs to either be in + * `webllm.prebuiltAppConfig` or in `engineConfig.appConfig`. * - * @note Not supported. Instead, call `CreateMLCEngine(model)` or `engine.reload(model)`. + * @note Call `CreateMLCEngine(model)` or `engine.reload(model)` ahead of time. + * @note If only one model is loaded in the engine, this field is optional. If multiple models + * are loaded, this is required. */ model?: string | null; } @@ -363,7 +364,7 @@ export interface ChatCompletionChunk { usage?: CompletionUsage; } -export const ChatCompletionRequestUnsupportedFields: Array = ["model"]; +export const ChatCompletionRequestUnsupportedFields: Array = []; // all supported as of now /** * Post init and verify whether the input of the request is valid. Thus, this function can throw diff --git a/src/openai_api_protocols/completion.ts b/src/openai_api_protocols/completion.ts index 66e0714c..0ed869cd 100644 --- a/src/openai_api_protocols/completion.ts +++ b/src/openai_api_protocols/completion.ts @@ -182,14 +182,18 @@ export interface CompletionCreateParamsBase { */ top_p?: number | null; - //////////////// BELOW FIELDS NOT SUPPORTED YET //////////////// /** - * Model to carry out this API. + * ID of the model to use. This equals to `ModelRecord.model_id`, which needs to either be in + * `webllm.prebuiltAppConfig` or in `engineConfig.appConfig`. * - * @note Not supported. Instead call `CreateMLCEngine(model)` or `engine.reload(model)` instead. + * @note Call `CreateMLCEngine(model)` or `engine.reload(model)` ahead of time. + * @note If only one model is loaded in the engine, this field is optional. If multiple models + * are loaded, this is required. */ model?: string | null; + //////////////// BELOW FIELDS NOT SUPPORTED YET //////////////// + /** * The suffix that comes after a completion of inserted text. * @@ -305,7 +309,6 @@ export interface CompletionChoice { //////////////////////////////// 3. POST INIT //////////////////////////////// export const CompletionCreateParamsUnsupportedFields: Array = [ - "model", "suffix", "user", "best_of", diff --git a/src/openai_api_protocols/embedding.ts b/src/openai_api_protocols/embedding.ts index f5eeef24..5f623eb4 100644 --- a/src/openai_api_protocols/embedding.ts +++ b/src/openai_api_protocols/embedding.ts @@ -118,18 +118,21 @@ export interface EmbeddingCreateParams { input: string | Array | Array | Array>; /** - * The format to return the embeddings in. + * ID of the model to use. 
This equals to `ModelRecord.model_id`, which needs to either be in + * `webllm.prebuiltAppConfig` or in `engineConfig.appConfig`. * - * @note Currently only support `float`. + * @note Call `CreateMLCEngine(model)` or `engine.reload(model)` ahead of time. + * @note If only one model is loaded in the engine, this field is optional. If multiple models + * are loaded, this is required. */ - encoding_format?: "float" | "base64"; + model?: string | null; /** - * ID of the model to use. + * The format to return the embeddings in. * - * @note Not supported. Instead, call `CreateMLCEngine(model)` or `engine.reload(model)`. + * @note Currently only support `float`. */ - model?: string; + encoding_format?: "float" | "base64"; // TODO: can support matryoshka embedding models in future, hence allow `dimensions` for those. /** @@ -149,7 +152,6 @@ export interface EmbeddingCreateParams { } export const EmbeddingCreateParamsUnsupportedFields: Array = [ - "model", "dimensions", "user", ]; diff --git a/src/service_worker.ts b/src/service_worker.ts index b1e2d32d..6e1c24ea 100644 --- a/src/service_worker.ts +++ b/src/service_worker.ts @@ -8,7 +8,7 @@ import { WebWorkerMLCEngine, ChatWorker, } from "./web_worker"; -import { areChatOptionsEqual } from "./utils"; +import { areArraysEqual, areChatOptionsListEqual } from "./utils"; import { NoServiceWorkerAPIError, NonWorkerEnvironmentError, @@ -110,8 +110,8 @@ export class ServiceWorkerMLCEngineHandler extends WebWorkerMLCEngineHandler { const params = msg.content as ReloadParams; // If the modelId, chatOpts, and appConfig are the same, immediately return if ( - this.modelId === params.modelId && - areChatOptionsEqual(this.chatOpts, params.chatOpts) + areArraysEqual(this.modelId, params.modelId) && + areChatOptionsListEqual(this.chatOpts, params.chatOpts) ) { log.info("Already loaded the model. Skip loading"); const gpuDetectOutput = await tvmjs.detectGPUDevice(); @@ -181,15 +181,18 @@ export class ServiceWorker implements ChatWorker { /** * Create a ServiceWorkerMLCEngine. * - * @param modelId The model to load, needs to either be in `webllm.prebuiltAppConfig`, or in - * `engineConfig.appConfig`. + * @param modelId model_id of the model to load, either string or string[]. When multiple models + * are provided, we load all models sequentially. Each modelId needs to either be in + * `webllm.prebuiltAppConfig`, or in `engineCOnfig.appConfig`. * @param engineConfig Optionally configures the engine, see `webllm.MLCEngineConfig` for more. + * @param chatOpts Extra options to optionally override the `mlc-chat-config.json` of `modelId`. + * The size of which needs to match that of `modelId`; chatOpts[i] will be used for modelId[i]. * @returns An initialized `WebLLM.ServiceWorkerMLCEngine` with `modelId` loaded. */ export async function CreateServiceWorkerMLCEngine( - modelId: string, + modelId: string | string[], engineConfig?: MLCEngineConfig, - chatOpts?: ChatOptions, + chatOpts?: ChatOptions | ChatOptions[], keepAliveMs = 10000, ): Promise { if (!("serviceWorker" in navigator)) { diff --git a/src/support.ts b/src/support.ts index 95b20584..d30a94e0 100644 --- a/src/support.ts +++ b/src/support.ts @@ -1,15 +1,18 @@ /** Util methods. 
*/ import { Tokenizer } from "@mlc-ai/web-tokenizers"; -import { AppConfig, MessagePlaceholders } from "./config"; +import { AppConfig, MessagePlaceholders, ModelRecord } from "./config"; import { ChatCompletionChunk, ChatCompletionMessageToolCall, } from "./openai_api_protocols/index"; import { ModelNotFoundError, + ModelNotLoadedError, + SpecifiedModelNotFoundError, ToolCallOutputInvalidTypeError, ToolCallOutputMissingFieldsError, ToolCallOutputParseError, + UnclearModelToUseError, } from "./error"; /** @@ -199,10 +202,52 @@ export function getToolCallFromOutputMessage( } } -export function findModelRecord(modelId: string, appConfig: AppConfig) { +export function findModelRecord( + modelId: string, + appConfig: AppConfig, +): ModelRecord { const matchedItem = appConfig.model_list.find( (item) => item.model_id == modelId, ); if (matchedItem !== undefined) return matchedItem; throw new ModelNotFoundError(modelId); } + +/** + * Return the model to use given the loaded modelIds and requestModel. Throws error when unclear + * which model to load. + * @param loadedModelIds Models currently loaded in the engine. + * @param requestModel Model the user specified to load via the request. Required when multiple + * models are loaded + * @param requestName The type of request or API to load the model for. Needed for error throwing. + */ +export function getModelIdToUse( + loadedModelIds: string[], + requestModel: string | undefined | null, + requestName: string, +): string { + let selectedModelId: string; + if (loadedModelIds.length === 0) { + throw new ModelNotLoadedError(requestName); + } + if (requestModel) { + // If specified model + if (loadedModelIds.indexOf(requestModel) === -1) { + throw new SpecifiedModelNotFoundError( + loadedModelIds, + requestModel, + requestName, + ); + } else { + selectedModelId = requestModel; + } + } else { + // If not specified + if (loadedModelIds.length > 1) { + throw new UnclearModelToUseError(loadedModelIds, requestName); + } else { + selectedModelId = loadedModelIds[0]; + } + } + return selectedModelId; +} diff --git a/src/types.ts b/src/types.ts index 1dc15899..d7c88846 100644 --- a/src/types.ts +++ b/src/types.ts @@ -99,12 +99,20 @@ export interface MLCEngineInterface { /** * Reload the chat with a new model. * - * @param modelId model_id of the model to load. - * @param chatOpts Extra options to override chat behavior. + * @param modelId model_id of the model to load, either string or string[]. When multiple models + * are provided, we load all models sequentially. Each modelId needs to either be in + * `webllm.prebuiltAppConfig`, or in `engineConfig.appConfig`. + * @param chatOpts Extra options to optionally override the `mlc-chat-config.json` of `modelId`. + * The size of which needs to match that of `modelId`; chatOpts[i] will be used for modelId[i]. * @returns A promise when reload finishes. + * @throws Throws error when device lost (mostly due to OOM); users should re-call reload(), + * potentially with a smaller model or smaller context window size. * @note This is an async function. */ - reload: (modelId: string, chatOpts?: ChatOptions) => Promise; + reload: ( + modelId: string | string[], + chatOpts?: ChatOptions | ChatOptions[], + ) => Promise; /** * OpenAI-style API. Generate a chat completion response for the given conversation and @@ -164,9 +172,10 @@ export interface MLCEngineInterface { /** * @returns A text summarizing the runtime stats. + * @param modelId Only required when multiple models are loaded. 
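The resolution rules implemented by `getModelIdToUse` in `src/support.ts` above reduce to three outcomes. A quick illustration with hypothetical model IDs, calling the internal helper directly (user code normally goes through the request-level APIs instead):

```typescript
import { getModelIdToUse } from "./support"; // internal helper shown above

const one = ["model-A"];
const two = ["model-A", "model-B"];

getModelIdToUse(one, undefined, "ChatCompletionRequest"); // -> "model-A"
getModelIdToUse(two, "model-B", "ChatCompletionRequest"); // -> "model-B"

// getModelIdToUse([], undefined, "...")   -> throws ModelNotLoadedError
// getModelIdToUse(two, undefined, "...")  -> throws UnclearModelToUseError
// getModelIdToUse(two, "model-C", "...")  -> throws SpecifiedModelNotFoundError
```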
* @note This is an async function */ - runtimeStatsText: () => Promise; + runtimeStatsText: (modelId?: string) => Promise; /** * Interrupt the generate process if it is already running. @@ -174,22 +183,25 @@ export interface MLCEngineInterface { interruptGenerate: () => void; /** - * Explicitly unload the current model and release the related resources. + * Explicitly unload the currently loaded model(s) and release the related resources. Waits until + * the webgpu device finishes all submitted work and destroys itself. + * @note This is an asynchronous function. */ unload: () => Promise; /** * Reset the current chat session by clear all memories. * @param keepStats: If True, do not reset the statistics. + * @param modelId Only required when multiple models are loaded. */ - resetChat: (keepStats?: boolean) => Promise; + resetChat: (keepStats?: boolean, modelId?: string) => Promise; /** * Get the current generated response. - * + * @param modelId Only required when multiple models are loaded. * @returns The current output message. */ - getMessage: () => Promise; + getMessage: (modelId?: string) => Promise; /** * Returns the device's maxStorageBufferBindingSize, can be used to guess whether the device @@ -210,12 +222,14 @@ export interface MLCEngineInterface { * * @param inputIds The input tokens. * @param isPrefill True if prefill, false if decode; only used for statistics. + * @param modelId Only required when multiple models are loaded. * @returns Next token sampled. * @note This is an async function. */ forwardTokensAndSample( inputIds: Array, isPrefill: boolean, + modelId?: string, ): Promise; /** diff --git a/src/utils.ts b/src/utils.ts index e28697a7..7c688927 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,7 +1,7 @@ import { AppConfig, ChatOptions, ModelRecord } from "./config"; // Helper function to compare two arrays -function areArraysEqual(arr1?: Array, arr2?: Array): boolean { +export function areArraysEqual(arr1?: Array, arr2?: Array): boolean { if (!arr1 && !arr2) return true; if (!arr1 || !arr2) return false; if (arr1.length !== arr2.length) return false; @@ -120,3 +120,28 @@ export function areChatOptionsEqual( // If all checks passed, the options are equal return true; } + +export function areChatOptionsListEqual( + options1?: ChatOptions[], + options2?: ChatOptions[], +): boolean { + if (options1 && options2) { + // Both defined, need to compare + if (options1.length !== options2.length) { + return false; + } else { + for (let i = 0; i < options1.length; i++) { + if (!areChatOptionsEqual(options1[i], options2[i])) { + return false; + } + } + return true; + } + } else if (!options1 && !options2) { + // Both undefined, equal + return true; + } else { + // One defined, other not + return false; + } +} diff --git a/src/web_worker.ts b/src/web_worker.ts index 5f4ec6f3..6243f828 100644 --- a/src/web_worker.ts +++ b/src/web_worker.ts @@ -34,6 +34,8 @@ import { CompletionNonStreamingParams, EmbeddingParams, CompletionStreamInitParams, + GetMessageParams, + RuntimeStatsTextParams, } from "./message"; import log from "loglevel"; import { MLCEngine } from "./engine"; @@ -41,6 +43,7 @@ import { UnknownMessageKindError, WorkerEngineModelNotLoadedError, } from "./error"; +import { areArraysEqual } from "./utils"; /** * Worker handler that can be used in a WebWorker @@ -56,14 +59,15 @@ import { export class WebWorkerMLCEngineHandler { /** * The modelId and chatOpts that the underlying engine (backend) is currently loaded with. 
+ * An engine can be loaded with multiple models, so modelId and chatOpts are lists. * * TODO(webllm-team): This is always in-sync with `this.engine` unless device is lost due to * unexpected reason. Therefore, we should get it from `this.engine` directly and make handler * stateless. Besides, consider if we should add appConfig, or use engine's API to find the * corresponding model record rather than relying on just the modelId. */ - modelId?: string; - chatOpts?: ChatOptions; + modelId?: string[]; + chatOpts?: ChatOptions[]; public engine: MLCEngine; /** ChatCompletion and Completion share the same chunk generator. */ @@ -151,6 +155,7 @@ export class WebWorkerMLCEngineHandler { const res = await this.engine.forwardTokensAndSample( params.inputIds, params.isPrefill, + params.modelId, ); onComplete?.(res); return res; @@ -238,7 +243,8 @@ export class WebWorkerMLCEngineHandler { } case "runtimeStatsText": { this.handleTask(msg.uuid, async () => { - const res = await this.engine.runtimeStatsText(); + const params = msg.content as RuntimeStatsTextParams; + const res = await this.engine.runtimeStatsText(params.modelId); onComplete?.(res); return res; }); @@ -266,7 +272,7 @@ export class WebWorkerMLCEngineHandler { case "resetChat": { this.handleTask(msg.uuid, async () => { const params = msg.content as ResetChatParams; - await this.engine.resetChat(params.keepStats); + await this.engine.resetChat(params.keepStats, params.modelId); onComplete?.(null); return null; }); @@ -290,7 +296,8 @@ export class WebWorkerMLCEngineHandler { } case "getMessage": { this.handleTask(msg.uuid, async () => { - const res = await this.engine.getMessage(); + const params = msg.content as GetMessageParams; + const res = await this.engine.getMessage(params.modelId); onComplete?.(res); return res; }); @@ -329,10 +336,11 @@ export class WebWorkerMLCEngineHandler { * to possibly killed service worker), we reload here. */ async reloadIfUnmatched( - expectedModelId: string, - expectedChatOpts: ChatOptions, + expectedModelId: string[], + expectedChatOpts?: ChatOptions[], ) { - if (this.modelId !== expectedModelId) { + // TODO: should we also check expectedChatOpts here? + if (!areArraysEqual(this.modelId, expectedModelId)) { log.warn( "WebWorkerMLCEngine expects model is loaded in WebWorkerMLCEngineHandler, " + "but it is not. This may due to web/service worker is unexpectedly killed.\n" + @@ -353,19 +361,22 @@ export interface ChatWorker { * * Equivalent to `new webllm.WebWorkerMLCEngine(worker).reload(...)`. * - * @param worker The worker that holds the actual MLCEngine, intialized with `new Worker()`. - * @param modelId The model to load, needs to either be in `webllm.prebuiltAppConfig`, or in - * `engineConfig.appConfig`. + * @param worker The worker that holds the actual MLCEngine, initialized with `new Worker()`. + * @param modelId model_id of the model to load, either string or string[]. When multiple models + * are provided, we load all models sequentially. Each modelId needs to either be in + * `webllm.prebuiltAppConfig`, or in `engineConfig.appConfig`. * @param engineConfig Optionally configures the engine, see `webllm.MLCEngineConfig` for more. + * @param chatOpts Extra options to optionally override the `mlc-chat-config.json` of `modelId`. + * Its length must match that of `modelId`; chatOpts[i] will be used for modelId[i]. * @returns An initialized `WebLLM.WebWorkerMLCEngine` with `modelId` loaded. * * @note engineConfig.logitProcessorRegistry is ignored for `CreateWebWorkMLCEngine()`.
*/ export async function CreateWebWorkerMLCEngine( worker: any, - modelId: string, + modelId: string | string[], engineConfig?: MLCEngineConfig, - chatOpts?: ChatOptions, + chatOpts?: ChatOptions | ChatOptions[], ): Promise { const webWorkerMLCEngine = new WebWorkerMLCEngine(worker, engineConfig); await webWorkerMLCEngine.reload(modelId, chatOpts); @@ -395,9 +406,10 @@ export class WebWorkerMLCEngine implements MLCEngineInterface { * The modelId and chatOpts that the frontend expects the backend engine is currently loaded * with. Needed for service worker. It is the backend and handler's job to match up with the * expectation despite the web/service worker possibly being killed. + * Since an engine can load multiple models, both modelId and chatOpts are lists. */ - modelId?: string; - chatOpts?: ChatOptions; + modelId?: string[]; + chatOpts?: ChatOptions[]; private initProgressCallback?: InitProgressCallback; private pendingPromise = new Map void>(); @@ -481,7 +493,18 @@ export class WebWorkerMLCEngine implements MLCEngineInterface { return promise; } - async reload(modelId: string, chatOpts?: ChatOptions): Promise { + async reload( + modelId: string | string[], + chatOpts?: ChatOptions | ChatOptions[], + ): Promise { + // Always convert modelId and chatOpts to lists internally for ease of manipulation + if (!Array.isArray(modelId)) { + modelId = [modelId]; + } + if (chatOpts !== undefined && !Array.isArray(chatOpts)) { + chatOpts = [chatOpts]; + } + const msg: WorkerRequest = { kind: "reload", uuid: crypto.randomUUID(), @@ -513,20 +536,24 @@ export class WebWorkerMLCEngine implements MLCEngineInterface { return await this.getPromise(msg); } - async getMessage(): Promise { + async getMessage(modelId?: string): Promise { const msg: WorkerRequest = { kind: "getMessage", uuid: crypto.randomUUID(), - content: null, + content: { + modelId: modelId, + }, }; return await this.getPromise(msg); } - async runtimeStatsText(): Promise { + async runtimeStatsText(modelId?: string): Promise { const msg: WorkerRequest = { kind: "runtimeStatsText", uuid: crypto.randomUUID(), - content: null, + content: { + modelId: modelId, + }, }; return await this.getPromise(msg); } @@ -551,12 +578,13 @@ export class WebWorkerMLCEngine implements MLCEngineInterface { this.chatOpts = undefined; } - async resetChat(keepStats = false): Promise { + async resetChat(keepStats = false, modelId?: string): Promise { const msg: WorkerRequest = { kind: "resetChat", uuid: crypto.randomUUID(), content: { keepStats: keepStats, + modelId: modelId, }, }; await this.getPromise(msg); @@ -565,6 +593,7 @@ export class WebWorkerMLCEngine implements MLCEngineInterface { async forwardTokensAndSample( inputIds: Array, isPrefill: boolean, + modelId?: string, ): Promise { const msg: WorkerRequest = { kind: "forwardTokensAndSample", @@ -572,6 +601,7 @@ export class WebWorkerMLCEngine implements MLCEngineInterface { content: { inputIds: inputIds, isPrefill: isPrefill, + modelId: modelId, }, }; return await this.getPromise(msg); diff --git a/tests/openai_chat_completion.test.ts b/tests/openai_chat_completion.test.ts index 96cd8d98..f7176650 100644 --- a/tests/openai_chat_completion.test.ts +++ b/tests/openai_chat_completion.test.ts @@ -33,21 +33,6 @@ describe("Check chat completion unsupported requests", () => { }).toThrow("Only specify stream_options when stream=True."); }); - test("High-level unsupported fields", () => { - expect(() => { - const request: ChatCompletionRequest = { - model: "phi-2-q4f32_1-MLC", // this raises error - messages: [ - { 
role: "system", content: "You are a helpful assistant." }, - { role: "user", content: "Hello! " }, - ], - }; - postInitAndCheckFields(request, "Llama-3.1-8B-Instruct-q4f32_1-MLC"); - }).toThrow( - "The following fields in ChatCompletionRequest are not yet supported", - ); - }); - test("Last message should be from user or tool", () => { expect(() => { const request: ChatCompletionRequest = { diff --git a/tests/openai_completion.test.ts b/tests/openai_completion.test.ts index bb22e8df..00b12ffb 100644 --- a/tests/openai_completion.test.ts +++ b/tests/openai_completion.test.ts @@ -55,16 +55,6 @@ describe("Check completion unsupported requests", () => { }); test("High-level unsupported fields", () => { - expect(() => { - const request: CompletionCreateParams = { - model: "phi-2-q4f32_1-MLC", // this raises error - prompt: "Hello, ", - }; - postInitAndCheckFields(request, "Llama-3.1-8B-Instruct-q4f32_1-MLC"); - }).toThrow( - "The following fields in CompletionCreateParams are not yet supported", - ); - expect(() => { const request: CompletionCreateParams = { prompt: "Hello, ", diff --git a/tests/openai_embeddings.test.ts b/tests/openai_embeddings.test.ts index dd704ad4..09fdec5c 100644 --- a/tests/openai_embeddings.test.ts +++ b/tests/openai_embeddings.test.ts @@ -98,17 +98,6 @@ describe("Check embeddings unsupported requests", () => { }).toThrow(new EmbeddingUnsupportedEncodingFormatError()); }); - test("model", () => { - expect(() => { - const request: EmbeddingCreateParams = { - input: ["Hello", "Hi"], - encoding_format: "float", - model: "snowflake-arctic-embed-m-q0f32-MLC", - }; - postInitAndCheckFields(request, "snowflake-arctic-embed-m-q0f32-MLC"); - }).toThrow("The following fields in"); - }); - test("user", () => { expect(() => { const request: EmbeddingCreateParams = { diff --git a/tests/util.test.ts b/tests/util.test.ts index 8cdc7955..f35e729d 100644 --- a/tests/util.test.ts +++ b/tests/util.test.ts @@ -1,4 +1,12 @@ -import { cleanModelUrl, getTopProbs } from "../src/support"; +import { ChatOptions } from "../src/config"; +import { + ModelNotLoadedError, + SpecifiedModelNotFoundError, + UnclearModelToUseError, +} from "../src/error"; +import { cleanModelUrl, getModelIdToUse, getTopProbs } from "../src/support"; +import { areChatOptionsListEqual } from "../src/utils"; +import { MLCEngine } from "../src/engine"; describe("Check getTopLogprobs correctness", () => { test("Correctness test 1", () => { @@ -56,3 +64,252 @@ describe("Test clean model URL", () => { expect(output).toEqual(expected); }); }); + +describe("Test getModelIdToUse", () => { + test("Specified model not found", () => { + const loadedModelIds = ["a", "b", "c"]; + const requestModel = "d"; + const requestName = "ChatCompletionRequest"; + expect(() => { + getModelIdToUse(loadedModelIds, requestModel, requestName); + }).toThrow( + new SpecifiedModelNotFoundError( + loadedModelIds, + requestModel, + requestName, + ), + ); + }); + + test("No model loaded", () => { + const loadedModelIds: string[] = []; + const requestModel = "d"; + const requestName = "ChatCompletionRequest"; + expect(() => { + getModelIdToUse(loadedModelIds, requestModel, requestName); + }).toThrow(new ModelNotLoadedError(requestName)); + }); + + test("Unclear what model to use, undefined", () => { + const loadedModelIds = ["a", "b", "c"]; + const requestModel = undefined; + const requestName = "ChatCompletionRequest"; + expect(() => { + getModelIdToUse(loadedModelIds, requestModel, requestName); + }).toThrow(new UnclearModelToUseError(loadedModelIds, 
requestName)); + }); + + test("Unclear what model to use, null", () => { + const loadedModelIds = ["a", "b", "c"]; + const requestModel = null; + const requestName = "ChatCompletionRequest"; + expect(() => { + getModelIdToUse(loadedModelIds, requestModel, requestName); + }).toThrow(new UnclearModelToUseError(loadedModelIds, requestName)); + }); + + test("Valid config, unspecified request model", () => { + const loadedModelIds = ["a"]; + const requestModel = null; + const requestName = "ChatCompletionRequest"; + const selectedModelId = getModelIdToUse( + loadedModelIds, + requestModel, + requestName, + ); + expect(selectedModelId).toEqual("a"); + }); + + test("Valid config, specified request model", () => { + const loadedModelIds = ["a"]; + const requestModel = "a"; + const requestName = "ChatCompletionRequest"; + const selectedModelId = getModelIdToUse( + loadedModelIds, + requestModel, + requestName, + ); + expect(selectedModelId).toEqual("a"); + }); + + test("Valid config, specified request model, multi models loaded", () => { + const loadedModelIds = ["a", "b", "c"]; + const requestModel = "c"; + const requestName = "ChatCompletionRequest"; + const selectedModelId = getModelIdToUse( + loadedModelIds, + requestModel, + requestName, + ); + expect(selectedModelId).toEqual("c"); + }); + + // Cannot test MLCEngine.getLLMStates E2E because `instanceof LLMChatPipeline` would not pass + // with dummy pipeline variables + test("E2E test with MLCEngine not loading a model for APIs", () => { + const engine = new MLCEngine(); + expect(async () => { + await engine.chatCompletion({ + messages: [{ role: "user", content: "hi" }], + }); + }).rejects.toThrow(new ModelNotLoadedError("ChatCompletionRequest")); + expect(async () => { + await engine.getMessage(); + }).rejects.toThrow(new ModelNotLoadedError("getMessage")); + + // resetChat should not throw error because it is allowed to resetChat before pipeline + // established, as a no-op + expect(async () => { + await engine.resetChat(); + }).not.toThrow(new ModelNotLoadedError("resetChat")); + }); + + test("E2E test with MLCEngine with two models without specifying a model", () => { + const engine = new MLCEngine() as any; + engine.loadedModelIdToPipeline = new Map(); + engine.loadedModelIdToPipeline.set("model1", "dummyLLMChatPipeline"); + engine.loadedModelIdToPipeline.set("model2", "dummyLLMChatPipeline"); + const loadedModelIds = ["model1", "model2"]; + + expect(async () => { + await engine.chatCompletion({ + messages: [{ role: "user", content: "hi" }], + }); + }).rejects.toThrow( + new UnclearModelToUseError(loadedModelIds, "ChatCompletionRequest"), + ); + expect(async () => { + await engine.getMessage(); + }).rejects.toThrow( + new UnclearModelToUseError(loadedModelIds, "getMessage"), + ); + expect(async () => { + await engine.resetChat(); + }).rejects.toThrow(new UnclearModelToUseError(loadedModelIds, "resetChat")); + }); + + test("E2E test with MLCEngine with two models specifying wrong model", () => { + const engine = new MLCEngine() as any; + engine.loadedModelIdToPipeline = new Map(); + engine.loadedModelIdToPipeline.set("model1", "dummyLLMChatPipeline"); + engine.loadedModelIdToPipeline.set("model2", "dummyLLMChatPipeline"); + const loadedModelIds = ["model1", "model2"]; + const requestedModelId = "model3"; + + expect(async () => { + await engine.chatCompletion({ + messages: [{ role: "user", content: "hi" }], + model: requestedModelId, + }); + }).rejects.toThrow( + new SpecifiedModelNotFoundError( + loadedModelIds, + requestedModelId, + 
"ChatCompletionRequest", + ), + ); + expect(async () => { + await engine.getMessage(requestedModelId); + }).rejects.toThrow( + new SpecifiedModelNotFoundError( + loadedModelIds, + requestedModelId, + "getMessage", + ), + ); + expect(async () => { + await engine.runtimeStatsText(requestedModelId); + }).rejects.toThrow( + new SpecifiedModelNotFoundError( + loadedModelIds, + requestedModelId, + "runtimeStatsText", + ), + ); + + // resetChat should not throw error because it is allowed to resetChat before pipeline + // established, as a no-op + expect(async () => { + await engine.resetChat(false, requestedModelId); + }).not.toThrow( + new SpecifiedModelNotFoundError( + loadedModelIds, + requestedModelId, + "resetChat", + ), + ); + }); +}); + +describe("Test areChatOptionsListEqual", () => { + const dummyChatOpts1: ChatOptions = { tokenizer_files: ["a", "b"] }; + const dummyChatOpts2: ChatOptions = {}; + const dummyChatOpts3: ChatOptions = { tokenizer_files: ["a", "b"] }; + const dummyChatOpts4: ChatOptions = { + tokenizer_files: ["a", "b"], + top_p: 0.5, + }; + + test("Two undefined", () => { + const options1: ChatOptions[] | undefined = undefined; + const options2: ChatOptions[] | undefined = undefined; + expect(areChatOptionsListEqual(options1, options2)).toEqual(true); + }); + + test("One undefined", () => { + const options1: ChatOptions[] | undefined = [dummyChatOpts1]; + const options2: ChatOptions[] | undefined = undefined; + expect(areChatOptionsListEqual(options1, options2)).toEqual(false); + }); + + test("Both defined, not equal", () => { + const options1: ChatOptions[] | undefined = [dummyChatOpts1]; + const options2: ChatOptions[] | undefined = [dummyChatOpts2]; + expect(areChatOptionsListEqual(options1, options2)).toEqual(false); + }); + + test("Different size", () => { + const options1: ChatOptions[] | undefined = [ + dummyChatOpts1, + dummyChatOpts3, + ]; + const options2: ChatOptions[] | undefined = [dummyChatOpts2]; + expect(areChatOptionsListEqual(options1, options2)).toEqual(false); + }); + + test("Same size, not equal 1", () => { + const options1: ChatOptions[] | undefined = [ + dummyChatOpts1, + dummyChatOpts3, + ]; + const options2: ChatOptions[] | undefined = [ + dummyChatOpts1, + dummyChatOpts2, + ]; + expect(areChatOptionsListEqual(options1, options2)).toEqual(false); + }); + + test("Same size, not equal 2", () => { + const options1: ChatOptions[] | undefined = [ + dummyChatOpts1, + dummyChatOpts3, + ]; + const options2: ChatOptions[] | undefined = [ + dummyChatOpts1, + dummyChatOpts4, + ]; + expect(areChatOptionsListEqual(options1, options2)).toEqual(false); + }); + + test("Same size, equal", () => { + const options1: ChatOptions[] | undefined = [ + dummyChatOpts1, + dummyChatOpts3, + ]; + const options2: ChatOptions[] | undefined = [ + dummyChatOpts3, + dummyChatOpts1, + ]; + expect(areChatOptionsListEqual(options1, options2)).toEqual(true); + }); +});