diff --git a/examples/vision-model/README.md b/examples/vision-model/README.md
new file mode 100644
index 00000000..7450aad8
--- /dev/null
+++ b/examples/vision-model/README.md
@@ -0,0 +1,14 @@
+# WebLLM Get Started App
+
+This folder provides a minimal demo of using the WebLLM vision-model API in a web app setting.
+To try it out, run the following commands in this folder:
+
+```bash
+npm install
+npm start
+```
+
+If you would like to hack on the WebLLM core package, change the web-llm
+dependency to `"file:../.."` and follow the build-from-source instructions
+in the project to build WebLLM locally. This option is only recommended
+if you need to modify the WebLLM core package.
diff --git a/examples/vision-model/package.json b/examples/vision-model/package.json
new file mode 100644
index 00000000..7950e36b
--- /dev/null
+++ b/examples/vision-model/package.json
@@ -0,0 +1,20 @@
+{
+  "name": "get-started",
+  "version": "0.1.0",
+  "private": true,
+  "scripts": {
+    "start": "parcel src/vision_model.html --port 8888",
+    "build": "parcel build src/vision_model.html --dist-dir lib"
+  },
+  "devDependencies": {
+    "buffer": "^5.7.1",
+    "parcel": "^2.8.3",
+    "process": "^0.11.10",
+    "tslib": "^2.3.1",
+    "typescript": "^4.9.5",
+    "url": "^0.11.3"
+  },
+  "dependencies": {
+    "@mlc-ai/web-llm": "file:../.."
+  }
+}
diff --git a/examples/vision-model/src/vision_model.html b/examples/vision-model/src/vision_model.html
new file mode 100644
index 00000000..fb50b530
--- /dev/null
+++ b/examples/vision-model/src/vision_model.html
@@ -0,0 +1,23 @@
+<!doctype html>
+<html>
+  <script>
+    webLLMGlobal = {};
+  </script>
+  <body>
+    <h2>WebLLM Test Page</h2>
+    Open console to see output
+    <br />
+    <br />
+    <label id="init-label"> </label>
+
+    <h3>Prompt</h3>
+    <label id="prompt-label"> </label>
+
+    <h3>Response</h3>
+    <label id="generate-label"> </label>
+    <br />
+    <label id="stats-label"> </label>
+
+    <script type="module" src="./vision_model.ts"></script>
+  </body>
+</html>
diff --git a/examples/vision-model/src/vision_model.ts b/examples/vision-model/src/vision_model.ts
new file mode 100644
index 00000000..411e4a70
--- /dev/null
+++ b/examples/vision-model/src/vision_model.ts
@@ -0,0 +1,94 @@
+import * as webllm from "@mlc-ai/web-llm";
+
+function setLabel(id: string, text: string) {
+  const label = document.getElementById(id);
+  if (label == null) {
+    throw Error("Cannot find label " + id);
+  }
+  label.innerText = text;
+}
+
+async function main() {
+  const initProgressCallback = (report: webllm.InitProgressReport) => {
+    setLabel("init-label", report.text);
+  };
+  const selectedModel = "Phi-3.5-vision-instruct-q4f16_1-MLC";
+  const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
+    selectedModel,
+    {
+      initProgressCallback: initProgressCallback,
+      logLevel: "INFO", // specify the log level
+    },
+  );
+
+  // 1. Single image input (with choices)
+  const messages: webllm.ChatCompletionMessageParam[] = [
+    {
+      role: "system",
+      content:
+        "You are a helpful and honest assistant that answers question concisely.",
+    },
+    {
+      role: "user",
+      content: [
+        { type: "text", text: "List the items in the image concisely." },
+        {
+          type: "image_url",
+          image_url: {
+            url: "https://www.ilankelman.org/stopsigns/australia.jpg",
+          },
+        },
+      ],
+    },
+  ];
+  const request0: webllm.ChatCompletionRequest = {
+    stream: false, // can be streaming, same behavior
+    messages: messages,
+  };
+  const reply0 = await engine.chat.completions.create(request0);
+  const replyMessage0 = await engine.getMessage();
+  console.log(reply0);
+  console.log(replyMessage0);
+  console.log(reply0.usage);
+
+  // 2. A follow up text-only question
+  messages.push({ role: "assistant", content: replyMessage0 });
+  messages.push({ role: "user", content: "What is special about this image?" });
+  const request1: webllm.ChatCompletionRequest = {
+    stream: false, // can be streaming, same behavior
+    messages: messages,
+  };
+  const reply1 = await engine.chat.completions.create(request1);
+  const replyMessage1 = await engine.getMessage();
+  console.log(reply1);
+  console.log(replyMessage1);
+  console.log(reply1.usage);
+
+  // 3. A follow up multi-image question
+  messages.push({ role: "assistant", content: replyMessage1 });
+  messages.push({
+    role: "user",
+    content: [
+      { type: "text", text: "What about these two images? Answer concisely." },
+      {
+        type: "image_url",
+        image_url: { url: "https://www.ilankelman.org/eiffeltower.jpg" },
+      },
+      {
+        type: "image_url",
+        image_url: { url: "https://www.ilankelman.org/sunset.jpg" },
+      },
+    ],
+  });
+  const request2: webllm.ChatCompletionRequest = {
+    stream: false, // can be streaming, same behavior
+    messages: messages,
+  };
+  const reply2 = await engine.chat.completions.create(request2);
+  const replyMessage2 = await engine.getMessage();
+  console.log(reply2);
+  console.log(replyMessage2);
+  console.log(reply2.usage);
+}
+
+main();
diff --git a/src/config.ts b/src/config.ts
index 2e7430d1..6184dc37 100644
--- a/src/config.ts
+++ b/src/config.ts
@@ -229,6 +229,7 @@ export function postInitAndCheckGenerationConfigValues(
 export enum ModelType {
   "LLM",
   "embedding",
+  "VLM", // vision-language model
 }
 
 /**
@@ -512,6 +513,37 @@ export const prebuiltAppConfig: AppConfig = {
         context_window_size: 1024,
       },
     },
+    // Phi-3.5-vision-instruct
+    {
+      model:
+        "https://huggingface.co/mlc-ai/Phi-3.5-vision-instruct-q4f16_1-MLC",
+      model_id: "Phi-3.5-vision-instruct-q4f16_1-MLC",
+      model_lib:
+        modelLibURLPrefix +
+        modelVersion +
+        "/Phi-3.5-vision-instruct-q4f16_1-ctx4k_cs1k-webgpu.wasm",
+      vram_required_MB: 3952.18,
+      low_resource_required: true,
+      overrides: {
+        context_window_size: 4096,
+      },
+      model_type: ModelType.VLM,
+    },
+    {
+      model:
+        "https://huggingface.co/mlc-ai/Phi-3.5-vision-instruct-q4f32_1-MLC",
+      model_id: "Phi-3.5-vision-instruct-q4f32_1-MLC",
+      model_lib:
+        modelLibURLPrefix +
+        modelVersion +
+        "/Phi-3.5-vision-instruct-q4f32_1-ctx4k_cs1k-webgpu.wasm",
+      vram_required_MB: 5879.84,
+      low_resource_required: true,
+      overrides: {
+        context_window_size: 4096,
+      },
+      model_type: ModelType.VLM,
+    },
     // Mistral variants
     {
       model:
diff --git a/src/conversation.ts b/src/conversation.ts
index 916f2a22..5fa0dd83 100644
--- a/src/conversation.ts
+++ b/src/conversation.ts
@@ -139,7 +139,7 @@ export class Conversation {
    */
   getPromptArrayLastRound() {
     if (this.isTextCompletion) {
-      throw new TextCompletionConversationError("getPromptyArrayLastRound");
+      throw new TextCompletionConversationError("getPromptArrayLastRound");
     }
     if (this.messages.length < 3) {
       throw Error("needs to call getPromptArray for the first message");
@@ -346,7 +346,7 @@ export function getConversationFromChatCompletionRequest(
  * encounter invalid request.
  *
  * @param request The chatCompletionRequest we are about to prefill for.
- * @returns The string used to set Conversatoin.function_string
+ * @returns The string used to set Conversation.function_string
  */
 export function getFunctionCallUsage(request: ChatCompletionRequest): string {
   if (
diff --git a/src/engine.ts b/src/engine.ts
index f201cbec..5a57325c 100644
--- a/src/engine.ts
+++ b/src/engine.ts
@@ -119,6 +119,8 @@ export class MLCEngine implements MLCEngineInterface {
   >;
   /** Maps each loaded model's modelId to its chatConfig */
   private loadedModelIdToChatConfig: Map<string, ChatConfig>;
+  /** Maps each loaded model's modelId to its modelType */
+  private loadedModelIdToModelType: Map<string, ModelType>;
   /** Maps each loaded model's modelId to a lock. Ensures
    * each model only processes one request at a time.
    */
@@ -141,6 +143,7 @@ export class MLCEngine implements MLCEngineInterface {
       LLMChatPipeline | EmbeddingPipeline
     >();
     this.loadedModelIdToChatConfig = new Map<string, ChatConfig>();
+    this.loadedModelIdToModelType = new Map<string, ModelType>();
     this.loadedModelIdToLock = new Map<string, CustomLock>();
     this.appConfig = engineConfig?.appConfig || prebuiltAppConfig;
     this.setLogLevel(engineConfig?.logLevel || DefaultLogLevel);
@@ -239,6 +242,7 @@ export class MLCEngine implements MLCEngineInterface {
     const logitProcessor = this.logitProcessorRegistry?.get(modelId);
     const tstart = performance.now();
 
+    // look up and parse model record, record model type
     const modelRecord = findModelRecord(modelId, this.appConfig);
     const baseUrl =
       typeof document !== "undefined"
@@ -248,7 +252,13 @@ export class MLCEngine implements MLCEngineInterface {
     if (!modelUrl.startsWith("http")) {
       modelUrl = new URL(modelUrl, baseUrl).href;
     }
+    const modelType =
+      modelRecord.model_type === undefined || modelRecord.model_type === null
+        ? ModelType.LLM
+        : modelRecord.model_type;
+    this.loadedModelIdToModelType.set(modelId, modelType);
 
+    // instantiate cache
     let configCache: tvmjs.ArtifactCacheTemplate;
     if (this.appConfig.useIndexedDBCache) {
       configCache = new tvmjs.ArtifactIndexedDBCache("webllm/config");
@@ -409,6 +419,7 @@ export class MLCEngine implements MLCEngineInterface {
     }
     this.loadedModelIdToPipeline.clear();
     this.loadedModelIdToChatConfig.clear();
+    this.loadedModelIdToModelType.clear();
     this.loadedModelIdToLock.clear();
     this.deviceLostIsError = true;
     if (this.reloadController) {
@@ -737,7 +748,13 @@ export class MLCEngine implements MLCEngineInterface {
     // 0. Check model loaded and preprocess inputs
     const [selectedModelId, selectedPipeline, selectedChatConfig] =
       this.getLLMStates("ChatCompletionRequest", request.model);
-    API.postInitAndCheckFieldsChatCompletion(request, selectedModelId);
+    const selectedModelType =
+      this.loadedModelIdToModelType.get(selectedModelId);
+    API.postInitAndCheckFieldsChatCompletion(
+      request,
+      selectedModelId,
+      selectedModelType!,
+    );
     const genConfig: GenerationConfig = {
       frequency_penalty: request.frequency_penalty,
       presence_penalty: request.presence_penalty,
diff --git a/src/error.ts b/src/error.ts
index ae09a90f..496a9dfc 100644
--- a/src/error.ts
+++ b/src/error.ts
@@ -124,19 +124,39 @@ export class ContentTypeError extends Error {
   }
 }
 
-export class UserMessageContentError extends Error {
-  constructor(content: any) {
+export class UnsupportedRoleError extends Error {
+  constructor(role: string) {
+    super(`Unsupported role of message: ${role}`);
+    this.name = "UnsupportedRoleError";
+  }
+}
+
+export class UserMessageContentErrorForNonVLM extends Error {
+  constructor(modelId: string, modelType: string, content: any) {
     super(
-      `User message only supports string content for now, but received: ${content}`,
+      `The model loaded is not of type ModelType.VLM (vision-language model). ` +
+        `Therefore, user message only supports string content, but received: ${content}\n` +
+        `Loaded modelId: ${modelId}, modelType: ${modelType}`,
     );
-    this.name = "UserMessageContentError";
+    this.name = "UserMessageContentErrorForNonVLM";
   }
 }
 
-export class UnsupportedRoleError extends Error {
-  constructor(role: string) {
-    super(`Unsupported role of message: ${role}`);
-    this.name = "UnsupportedRoleError";
+export class UnsupportedDetailError extends Error {
+  constructor(detail: string) {
+    super(
+      `Currently do not support field image_url.detail, but received: ${detail}`,
+    );
+    this.name = "UnsupportedDetailError";
+  }
+}
+
+export class MultipleTextContentError extends Error {
+  constructor(numTextContent: number) {
+    super(
+      `Each message can have at most one text contentPart, but received: ${numTextContent}`,
+    );
+    this.name = "MultipleTextContentError";
   }
 }
 
diff --git a/src/openai_api_protocols/chat_completion.ts b/src/openai_api_protocols/chat_completion.ts
index 492b869f..c0c927d6 100644
--- a/src/openai_api_protocols/chat_completion.ts
+++ b/src/openai_api_protocols/chat_completion.ts
@@ -16,7 +16,11 @@
  */
 
 import { MLCEngineInterface } from "../types";
-import { functionCallingModelIds, MessagePlaceholders } from "../config";
+import {
+  functionCallingModelIds,
+  MessagePlaceholders,
+  ModelType,
+} from "../config";
 import {
   officialHermes2FunctionCallSchemaArray,
   hermes2FunctionCallingSystemPrompt,
@@ -27,12 +31,14 @@ import {
   InvalidResponseFormatError,
   InvalidStreamOptionsError,
   MessageOrderError,
+  MultipleTextContentError,
   SeedTypeError,
   StreamingCountError,
   SystemMessageOrderError,
+  UnsupportedDetailError,
   UnsupportedFieldsError,
   UnsupportedModelIdError,
-  UserMessageContentError,
+  UserMessageContentErrorForNonVLM,
 } from "../error";
 
 /* eslint-disable @typescript-eslint/no-namespace */
@@ -371,10 +377,12 @@ export const ChatCompletionRequestUnsupportedFields: Array<string> = []; // all
  * error or in-place update request.
  * @param request User's input request.
  * @param currentModelId The current model loaded that will perform this request.
+ * @param currentModelType The type of the model loaded, which decides what requests can be handled.
  */
 export function postInitAndCheckFields(
   request: ChatCompletionRequest,
   currentModelId: string,
+  currentModelType: ModelType,
 ): void {
   // Generation-related checks and post inits are in `postInitAndCheckGenerationConfigValues()`
   // 1. Check unsupported fields in request
@@ -391,10 +399,33 @@ export function postInitAndCheckFields(
   // 2. Check unsupported messages
   request.messages.forEach(
     (message: ChatCompletionMessageParam, index: number) => {
+      // Check content array messages (that are not simple string)
       if (message.role === "user" && typeof message.content !== "string") {
-        // ChatCompletionUserMessageParam
-        // Remove this when we support image input
-        throw new UserMessageContentError(message.content);
+        if (currentModelType !== ModelType.VLM) {
+          // Only VLM can handle non-string content (i.e. message with image)
+          throw new UserMessageContentErrorForNonVLM(
+            currentModelId,
+            ModelType[currentModelType],
+            message.content,
+          );
+        }
+        let numTextContent = 0;
+        for (let i = 0; i < message.content.length; i++) {
+          const curContent = message.content[i];
+          if (curContent.type === "image_url") {
+            // Do not support image_url.detail
+            const detail = curContent.image_url.detail;
+            if (detail !== undefined && detail !== null) {
+              throw new UnsupportedDetailError(detail);
+            }
+          } else {
+            numTextContent += 1;
+          }
+        }
+        if (numTextContent > 1) {
+          // Only one text contentPart per message
+          throw new MultipleTextContentError(numTextContent);
+        }
       }
       if (message.role === "system" && index !== 0) {
         throw new SystemMessageOrderError();
diff --git a/tests/openai_chat_completion.test.ts b/tests/openai_chat_completion.test.ts
index f7176650..f8dcf27f 100644
--- a/tests/openai_chat_completion.test.ts
+++ b/tests/openai_chat_completion.test.ts
@@ -8,7 +8,7 @@ import {
   hermes2FunctionCallingSystemPrompt,
   officialHermes2FunctionCallSchemaArray,
 } from "../src/support";
-import { MessagePlaceholders } from "../src/config";
+import { MessagePlaceholders, ModelType } from "../src/config";
 import { describe, expect, test } from "@jest/globals";
 
 describe("Check chat completion unsupported requests", () => {
@@ -18,7 +18,11 @@ describe("Check chat completion unsupported requests", () => {
       messages: [{ role: "user", content: "Hello! " }],
       stream_options: { include_usage: true },
     };
-      postInitAndCheckFields(request, "Llama-3.1-8B-Instruct-q4f32_1-MLC");
+      postInitAndCheckFields(
+        request,
+        "Llama-3.1-8B-Instruct-q4f32_1-MLC",
+        ModelType.LLM,
+      );
     }).toThrow("Only specify stream_options when stream=True.");
   });
@@ -29,7 +33,11 @@ describe("Check chat completion unsupported requests", () => {
       messages: [{ role: "user", content: "Hello! " }],
       stream_options: { include_usage: true },
     };
-      postInitAndCheckFields(request, "Llama-3.1-8B-Instruct-q4f32_1-MLC");
+      postInitAndCheckFields(
+        request,
+        "Llama-3.1-8B-Instruct-q4f32_1-MLC",
+        ModelType.LLM,
+      );
     }).toThrow("Only specify stream_options when stream=True.");
   });
@@ -42,7 +50,11 @@ describe("Check chat completion unsupported requests", () => {
         { role: "assistant", content: "Hello! How may I help you today?" },
       ],
     };
-      postInitAndCheckFields(request, "Llama-3.1-8B-Instruct-q4f32_1-MLC");
+      postInitAndCheckFields(
+        request,
+        "Llama-3.1-8B-Instruct-q4f32_1-MLC",
+        ModelType.LLM,
+      );
     }).toThrow("Last message should be from either `user` or `tool`.");
   });
@@ -56,7 +68,11 @@ describe("Check chat completion unsupported requests", () => {
         { role: "system", content: "You are a helpful assistant." },
       ],
     };
-      postInitAndCheckFields(request, "Llama-3.1-8B-Instruct-q4f32_1-MLC");
+      postInitAndCheckFields(
+        request,
+        "Llama-3.1-8B-Instruct-q4f32_1-MLC",
+        ModelType.LLM,
+      );
     }).toThrow(
       "System prompt should always be the first message in `messages`.",
     );
   });
@@ -69,7 +85,11 @@ describe("Check chat completion unsupported requests", () => {
       n: 2,
       messages: [{ role: "user", content: "Hello! " }],
     };
-      postInitAndCheckFields(request, "Llama-3.1-8B-Instruct-q4f32_1-MLC");
+      postInitAndCheckFields(
+        request,
+        "Llama-3.1-8B-Instruct-q4f32_1-MLC",
+        ModelType.LLM,
+      );
     }).toThrow("When streaming, `n` cannot be > 1.");
   });
@@ -80,7 +100,11 @@ describe("Check chat completion unsupported requests", () => {
       max_tokens: 10,
       seed: 42.2, // Note that Number.isInteger(42.0) is true
     };
-      postInitAndCheckFields(request, "Llama-3.1-8B-Instruct-q4f32_1-MLC");
+      postInitAndCheckFields(
+        request,
+        "Llama-3.1-8B-Instruct-q4f32_1-MLC",
+        ModelType.LLM,
+      );
     }).toThrow("`seed` should be an integer, but got");
   });
@@ -90,14 +114,17 @@ describe("Check chat completion unsupported requests", () => {
       messages: [{ role: "user", content: "Hello! " }],
       response_format: { schema: "some json schema" },
     };
-      postInitAndCheckFields(request, "Llama-3.1-8B-Instruct-q4f32_1-MLC");
+      postInitAndCheckFields(
+        request,
+        "Llama-3.1-8B-Instruct-q4f32_1-MLC",
+        ModelType.LLM,
+      );
     }).toThrow(
       "JSON schema is only supported with `json_object` response format.",
     );
   });
 
-  // Remove when we support image input (e.g. LlaVA model)
-  test("Image input is unsupported", () => {
+  test("image_url.detail is unsupported", () => {
     expect(() => {
       const request: ChatCompletionRequest = {
         messages: [
@@ -107,15 +134,79 @@ describe("Check chat completion unsupported requests", () => {
               { type: "text", text: "What is in this image?" },
               {
                 type: "image_url",
-                image_url: { url: "https://url_here.jpg" },
+                image_url: {
+                  url: "https://url_here.jpg",
+                  detail: "high",
+                },
               },
             ],
           },
         ],
       };
-      postInitAndCheckFields(request, "Llama-3.1-8B-Instruct-q4f32_1-MLC");
+      postInitAndCheckFields(
+        request,
+        "Phi-3.5-vision-instruct-q4f16_1-MLC",
+        ModelType.VLM,
+      );
     }).toThrow(
-      "User message only supports string content for now, but received:",
+      "Currently do not support field image_url.detail, but received: high",
     );
   });
+
+  test("User content cannot have multiple text content parts", () => {
+    expect(() => {
+      const request: ChatCompletionRequest = {
+        messages: [
+          {
+            role: "user",
+            content: [
+              { type: "text", text: "What is in this image?" },
+              {
+                type: "image_url",
+                image_url: {
+                  url: "https://url_here.jpg",
+                },
+              },
+              { type: "text", text: "Thank you." },
+            ],
+          },
+        ],
+      };
+      postInitAndCheckFields(
+        request,
+        "Phi-3.5-vision-instruct-q4f16_1-MLC",
+        ModelType.VLM,
+      );
+    }).toThrow(
+      "Each message can have at most one text contentPart, but received: 2",
+    );
+  });
+
+  test("Non-VLM cannot support non-string content", () => {
+    expect(() => {
+      const request: ChatCompletionRequest = {
+        messages: [
+          {
+            role: "user",
+            content: [
+              { type: "text", text: "What is in this image?" },
+              {
+                type: "image_url",
+                image_url: {
+                  url: "https://url_here.jpg",
+                },
+              },
+            ],
+          },
+        ],
+      };
+      postInitAndCheckFields(
+        request,
+        "Llama-3.1-8B-Instruct-q4f32_1-MLC",
+        ModelType.LLM,
+      );
+    }).toThrow(
+      "The model loaded is not of type ModelType.VLM (vision-language model).",
+    );
+  });
 });
@@ -142,7 +233,37 @@ describe("Supported requests", () => {
         "7660": 5,
       },
     };
-    postInitAndCheckFields(request, "Llama-3.1-8B-Instruct-q4f32_1-MLC");
+    postInitAndCheckFields(
+      request,
+      "Llama-3.1-8B-Instruct-q4f32_1-MLC",
+      ModelType.LLM,
+    );
+  });
+
+  test("Support image input, single or multiple images", () => {
+    const request: ChatCompletionRequest = {
+      messages: [
+        {
+          role: "user",
+          content: [
+            { type: "text", text: "What is in this image?" },
+            {
+              type: "image_url",
+              image_url: { url: "https://url_here1.jpg" },
+            },
+            {
+              type: "image_url",
+              image_url: { url: "https://url_here2.jpg" },
+            },
+          ],
+        },
+      ],
+    };
+    postInitAndCheckFields(
+      request,
+      "Phi-3.5-vision-instruct-q4f16_1-MLC",
+      ModelType.VLM,
+    );
   });
 });
@@ -167,7 +288,11 @@ describe("Manual function calling", () => {
       },
     ],
   };
-  postInitAndCheckFields(request, "Hermes-2-Theta-Llama-3-8B-q4f16_1-MLC");
+  postInitAndCheckFields(
+    request,
+    "Hermes-2-Theta-Llama-3-8B-q4f16_1-MLC",
+    ModelType.LLM,
+  );
   });
 });
@@ -204,7 +329,11 @@ describe("OpenAI API function calling", () => {
         },
       ],
     };
-      postInitAndCheckFields(request, "Llama-3.1-8B-Instruct-q4f32_1-MLC");
+      postInitAndCheckFields(
+        request,
+        "Llama-3.1-8B-Instruct-q4f32_1-MLC",
+        ModelType.LLM,
+      );
     }).toThrow(
       "Llama-3.1-8B-Instruct-q4f32_1-MLC is not supported for ChatCompletionRequest.tools.",
     );
   });
@@ -222,7 +351,11 @@ describe("OpenAI API function calling", () => {
       ],
       response_format: { type: "json_object" },
     };
-      postInitAndCheckFields(request, "Hermes-2-Pro-Llama-3-8B-q4f16_1-MLC");
+      postInitAndCheckFields(
+        request,
+        "Hermes-2-Pro-Llama-3-8B-q4f16_1-MLC",
+        ModelType.LLM,
+      );
     }).toThrow(
       "When using Hermes-2-Pro function calling via ChatCompletionRequest.tools, " +
        "cannot specify customized response_format. We will set it for you internally.",
     );
   });
@@ -244,7 +377,11 @@ describe("OpenAI API function calling", () => {
        },
      ],
    };
-      postInitAndCheckFields(request, "Hermes-2-Pro-Llama-3-8B-q4f16_1-MLC");
+      postInitAndCheckFields(
+        request,
+        "Hermes-2-Pro-Llama-3-8B-q4f16_1-MLC",
+        ModelType.LLM,
+      );
     }).toThrow(
       "When using Hermes-2-Pro function calling via ChatCompletionRequest.tools, cannot " +
        "specify customized system prompt.",
     );
   });
@@ -266,7 +403,11 @@ describe("OpenAI API function calling", () => {
        },
      ],
    };
-      postInitAndCheckFields(request, "Hermes-2-Pro-Llama-3-8B-q4f16_1-MLC");
+      postInitAndCheckFields(
+        request,
+        "Hermes-2-Pro-Llama-3-8B-q4f16_1-MLC",
+        ModelType.LLM,
+      );
     }).toThrow(
       "When using Hermes-2-Pro function calling via ChatCompletionRequest.tools, cannot " +
        "specify customized system prompt.",
     );
   });
@@ -283,7 +424,11 @@ describe("OpenAI API function calling", () => {
       },
     ],
   };
-    postInitAndCheckFields(request, "Hermes-2-Pro-Llama-3-8B-q4f16_1-MLC");
+    postInitAndCheckFields(
+      request,
+      "Hermes-2-Pro-Llama-3-8B-q4f16_1-MLC",
+      ModelType.LLM,
+    );
     expect(request.messages[0].role).toEqual("system");
     expect(request.messages[0].content).toEqual(
       hermes2FunctionCallingSystemPrompt.replace(