Support image input in chatCompletionRequest
- Allow message `content` to be an array of content parts, which can include `image_url` entries (see the request sketch below)
- Introduce `ModelType.VLM` so that only vision-language models accept non-string message content
- Pass the loaded model's type into `postInitAndCheckFieldsChatCompletion`, and add `loadedModelIdToModelType` to `MLCEngine` to support this
- Update unit tests accordingly
CharlieFRuan committed Sep 18, 2024
1 parent 429e719 commit fad3df9
Showing 10 changed files with 432 additions and 36 deletions.
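As a quick sketch of the new request shape introduced here (the content parts mirror the vision-model example added below), a user message may now mix text and image parts:

```typescript
import * as webllm from "@mlc-ai/web-llm";

// Minimal sketch: array content with an image_url part is only accepted
// when the loaded model is registered as ModelType.VLM.
const request: webllm.ChatCompletionRequest = {
  messages: [
    {
      role: "user",
      content: [
        { type: "text", text: "List the items in the image concisely." },
        {
          type: "image_url",
          image_url: { url: "https://www.ilankelman.org/stopsigns/australia.jpg" },
        },
      ],
    },
  ],
};
```

Non-VLM models continue to accept string content only; sending array content to them is rejected by the new checks (see the new error classes in src/error.ts below).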
14 changes: 14 additions & 0 deletions examples/vision-model/README.md
@@ -0,0 +1,14 @@
# WebLLM Vision Model Example

This folder provides a minimal demo of using the WebLLM API with a vision-language model in a webapp setting.
To try it out, run the following commands in this folder:

```bash
npm install
npm start
```

Note: if you would like to hack on the WebLLM core package, you can change the web-llm dependency
to `"file:../.."` and follow the build-from-source instructions in the project to build WebLLM
locally. This option is only recommended if you intend to modify the WebLLM core package.
20 changes: 20 additions & 0 deletions examples/vision-model/package.json
@@ -0,0 +1,20 @@
{
"name": "get-started",
"version": "0.1.0",
"private": true,
"scripts": {
"start": "parcel src/vision_model.html --port 8888",
"build": "parcel build src/vision_model.html --dist-dir lib"
},
"devDependencies": {
"buffer": "^5.7.1",
"parcel": "^2.8.3",
"process": "^0.11.10",
"tslib": "^2.3.1",
"typescript": "^4.9.5",
"url": "^0.11.3"
},
"dependencies": {
"@mlc-ai/web-llm": "file:../.."
}
}
23 changes: 23 additions & 0 deletions examples/vision-model/src/vision_model.html
@@ -0,0 +1,23 @@
<!doctype html>
<html>
<script>
webLLMGlobal = {};
</script>
<body>
<h2>WebLLM Test Page</h2>
Open console to see output
<br />
<br />
<label id="init-label"> </label>

<h3>Prompt</h3>
<label id="prompt-label"> </label>

<h3>Response</h3>
<label id="generate-label"> </label>
<br />
<label id="stats-label"> </label>

<script type="module" src="./vision_model.ts"></script>
</body>
</html>
94 changes: 94 additions & 0 deletions examples/vision-model/src/vision_model.ts
@@ -0,0 +1,94 @@
import * as webllm from "@mlc-ai/web-llm";

function setLabel(id: string, text: string) {
const label = document.getElementById(id);
if (label == null) {
throw Error("Cannot find label " + id);
}
label.innerText = text;
}

async function main() {
const initProgressCallback = (report: webllm.InitProgressReport) => {
setLabel("init-label", report.text);
};
const selectedModel = "Phi-3.5-vision-instruct-q4f16_1-MLC";
const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
selectedModel,
{
initProgressCallback: initProgressCallback,
logLevel: "INFO", // specify the log level
},
);

// 1. Single image input (with choices)
const messages: webllm.ChatCompletionMessageParam[] = [
{
role: "system",
content:
"You are a helpful and honest assistant that answers question concisely.",
},
{
role: "user",
content: [
{ type: "text", text: "List the items in the image concisely." },
{
type: "image_url",
image_url: {
url: "https://www.ilankelman.org/stopsigns/australia.jpg",
},
},
],
},
];
const request0: webllm.ChatCompletionRequest = {
stream: false, // can be streaming, same behavior
messages: messages,
};
const reply0 = await engine.chat.completions.create(request0);
const replyMessage0 = await engine.getMessage();
console.log(reply0);
console.log(replyMessage0);
console.log(reply0.usage);

// 2. A follow up text-only question
messages.push({ role: "assistant", content: replyMessage0 });
messages.push({ role: "user", content: "What is special about this image?" });
const request1: webllm.ChatCompletionRequest = {
stream: false, // can be streaming, same behavior
messages: messages,
};
const reply1 = await engine.chat.completions.create(request1);
const replyMessage1 = await engine.getMessage();
console.log(reply1);
console.log(replyMessage1);
console.log(reply1.usage);

// 3. A follow up multi-image question
messages.push({ role: "assistant", content: replyMessage1 });
messages.push({
role: "user",
content: [
{ type: "text", text: "What about these two images? Answer concisely." },
{
type: "image_url",
image_url: { url: "https://www.ilankelman.org/eiffeltower.jpg" },
},
{
type: "image_url",
image_url: { url: "https://www.ilankelman.org/sunset.jpg" },
},
],
});
const request2: webllm.ChatCompletionRequest = {
stream: false, // can be streaming, same behavior
messages: messages,
};
const reply2 = await engine.chat.completions.create(request2);
const replyMessage2 = await engine.getMessage();
console.log(reply2);
console.log(replyMessage2);
console.log(reply2.usage);
}

main();
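The `// can be streaming, same behavior` comments above refer to the usual chunked interface; a streaming variant of the final request could look roughly like this inside `main()` (a sketch; `stream_options.include_usage` is assumed to behave as in WebLLM's OpenAI-compatible streaming API):

```typescript
// Streaming sketch: same messages as above, but consume the reply chunk by chunk.
const asyncChunkGenerator = await engine.chat.completions.create({
  stream: true,
  stream_options: { include_usage: true },
  messages: messages,
});
let streamedReply = "";
for await (const chunk of asyncChunkGenerator) {
  streamedReply += chunk.choices[0]?.delta?.content ?? "";
  if (chunk.usage) {
    console.log(chunk.usage); // usage is only attached to the final chunk
  }
}
console.log(streamedReply);
```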
32 changes: 32 additions & 0 deletions src/config.ts
@@ -229,6 +229,7 @@ export function postInitAndCheckGenerationConfigValues(
export enum ModelType {
"LLM",
"embedding",
"VLM", // vision-language model
}

/**
@@ -512,6 +513,37 @@ export const prebuiltAppConfig: AppConfig = {
context_window_size: 1024,
},
},
// Phi-3.5-vision-instruct
{
model:
"https://huggingface.co/mlc-ai/Phi-3.5-vision-instruct-q4f16_1-MLC",
model_id: "Phi-3.5-vision-instruct-q4f16_1-MLC",
model_lib:
modelLibURLPrefix +
modelVersion +
"/Phi-3.5-vision-instruct-q4f16_1-ctx4k_cs1k-webgpu.wasm",
vram_required_MB: 3952.18,
low_resource_required: true,
overrides: {
context_window_size: 4096,
},
model_type: ModelType.VLM,
},
{
model:
"https://huggingface.co/mlc-ai/Phi-3.5-vision-instruct-q4f32_1-MLC",
model_id: "Phi-3.5-vision-instruct-q4f32_1-MLC",
model_lib:
modelLibURLPrefix +
modelVersion +
"/Phi-3.5-vision-instruct-q4f32_1-ctx4k_cs1k-webgpu.wasm",
vram_required_MB: 5879.84,
low_resource_required: true,
overrides: {
context_window_size: 4096,
},
model_type: ModelType.VLM,
},
// Mistral variants
{
model:
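The two Phi-3.5-vision entries above are the prebuilt records tagged `ModelType.VLM`. For illustration, a user-supplied `AppConfig` could register a custom VLM record of the same shape: a hypothetical sketch (the model id and URLs below are placeholders, and `ModelType` is assumed to be exported from the package entry point alongside the other config types):

```typescript
import { AppConfig, ModelType, prebuiltAppConfig } from "@mlc-ai/web-llm";

// Hypothetical custom record; the model and model_lib URLs are placeholders.
const appConfig: AppConfig = {
  ...prebuiltAppConfig,
  model_list: [
    ...prebuiltAppConfig.model_list,
    {
      model: "https://huggingface.co/my-org/My-VLM-q4f16_1-MLC",
      model_id: "My-VLM-q4f16_1-MLC",
      model_lib: "https://my-host.example/My-VLM-q4f16_1-webgpu.wasm",
      vram_required_MB: 4000,
      low_resource_required: true,
      overrides: { context_window_size: 4096 },
      model_type: ModelType.VLM, // marks the record as a vision-language model
    },
  ],
};
```

Such a config would then be passed to `CreateMLCEngine` via the engine config's `appConfig` field.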
4 changes: 2 additions & 2 deletions src/conversation.ts
@@ -139,7 +139,7 @@ export class Conversation {
*/
getPromptArrayLastRound() {
if (this.isTextCompletion) {
throw new TextCompletionConversationError("getPromptyArrayLastRound");
throw new TextCompletionConversationError("getPromptArrayLastRound");
}
if (this.messages.length < 3) {
throw Error("needs to call getPromptArray for the first message");
@@ -346,7 +346,7 @@ export function getConversationFromChatCompletionRequest(
* encounter invalid request.
*
* @param request The chatCompletionRequest we are about to prefill for.
* @returns The string used to set Conversatoin.function_string
* @returns The string used to set Conversation.function_string
*/
export function getFunctionCallUsage(request: ChatCompletionRequest): string {
if (
19 changes: 18 additions & 1 deletion src/engine.ts
@@ -119,6 +119,8 @@ export class MLCEngine implements MLCEngineInterface {
>;
/** Maps each loaded model's modelId to its chatConfig */
private loadedModelIdToChatConfig: Map<string, ChatConfig>;
/** Maps each loaded model's modelId to its modelType */
private loadedModelIdToModelType: Map<string, ModelType>;
/** Maps each loaded model's modelId to a lock. Ensures
* each model only processes one request at a time.
*/
@@ -141,6 +143,7 @@
LLMChatPipeline | EmbeddingPipeline
>();
this.loadedModelIdToChatConfig = new Map<string, ChatConfig>();
this.loadedModelIdToModelType = new Map<string, ModelType>();
this.loadedModelIdToLock = new Map<string, CustomLock>();
this.appConfig = engineConfig?.appConfig || prebuiltAppConfig;
this.setLogLevel(engineConfig?.logLevel || DefaultLogLevel);
@@ -239,6 +242,7 @@ export class MLCEngine implements MLCEngineInterface {
const logitProcessor = this.logitProcessorRegistry?.get(modelId);
const tstart = performance.now();

// look up and parse model record, record model type
const modelRecord = findModelRecord(modelId, this.appConfig);
const baseUrl =
typeof document !== "undefined"
@@ -248,7 +252,13 @@
if (!modelUrl.startsWith("http")) {
modelUrl = new URL(modelUrl, baseUrl).href;
}
const modelType =
modelRecord.model_type === undefined || modelRecord.model_type === null
? ModelType.LLM
: modelRecord.model_type;
this.loadedModelIdToModelType.set(modelId, modelType);

// instantiate cache
let configCache: tvmjs.ArtifactCacheTemplate;
if (this.appConfig.useIndexedDBCache) {
configCache = new tvmjs.ArtifactIndexedDBCache("webllm/config");
@@ -409,6 +419,7 @@ export class MLCEngine implements MLCEngineInterface {
}
this.loadedModelIdToPipeline.clear();
this.loadedModelIdToChatConfig.clear();
this.loadedModelIdToModelType.clear();
this.loadedModelIdToLock.clear();
this.deviceLostIsError = true;
if (this.reloadController) {
@@ -737,7 +748,13 @@ export class MLCEngine implements MLCEngineInterface {
// 0. Check model loaded and preprocess inputs
const [selectedModelId, selectedPipeline, selectedChatConfig] =
this.getLLMStates("ChatCompletionRequest", request.model);
API.postInitAndCheckFieldsChatCompletion(request, selectedModelId);
const selectedModelType =
this.loadedModelIdToModelType.get(selectedModelId);
API.postInitAndCheckFieldsChatCompletion(
request,
selectedModelId,
selectedModelType!,
);
const genConfig: GenerationConfig = {
frequency_penalty: request.frequency_penalty,
presence_penalty: request.presence_penalty,
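From the caller's side, the effect of threading the model type into `postInitAndCheckFieldsChatCompletion` is that array content sent to a model not registered as `ModelType.VLM` should now fail fast. A sketch of that failure path, assuming `engine` was loaded with a text-only (non-VLM) model:

```typescript
// Hypothetical failure path: image content against a non-VLM model is
// expected to throw during the post-init request check, before prefill.
try {
  await engine.chat.completions.create({
    messages: [
      {
        role: "user",
        content: [
          { type: "text", text: "What is in this image?" },
          {
            type: "image_url",
            image_url: { url: "https://www.ilankelman.org/sunset.jpg" },
          },
        ],
      },
    ],
  });
} catch (err) {
  console.error(err); // e.g. UserMessageContentErrorForNonVLM, defined in src/error.ts below
}
```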
36 changes: 28 additions & 8 deletions src/error.ts
@@ -124,19 +124,39 @@ export class ContentTypeError extends Error {
}
}

export class UserMessageContentError extends Error {
constructor(content: any) {
export class UnsupportedRoleError extends Error {
constructor(role: string) {
super(`Unsupported role of message: ${role}`);
this.name = "UnsupportedRoleError";
}
}

export class UserMessageContentErrorForNonVLM extends Error {
constructor(modelId: string, modelType: string, content: any) {
super(
`User message only supports string content for now, but received: ${content}`,
`The model loaded is not of type ModelType.VLM (vision-language model). ` +
`Therefore, user message only supports string content, but received: ${content}\n` +
`Loaded modelId: ${modelId}, modelType: ${modelType}`,
);
this.name = "UserMessageContentError";
this.name = "UserMessageContentErrorForNonVLM";
}
}

export class UnsupportedRoleError extends Error {
constructor(role: string) {
super(`Unsupported role of message: ${role}`);
this.name = "UnsupportedRoleError";
export class UnsupportedDetailError extends Error {
constructor(detail: string) {
super(
`Currently do not support field image_url.detail, but received: ${detail}`,
);
this.name = "UnsupportedDetailError";
}
}

export class MultipleTextContentError extends Error {
constructor(numTextContent: number) {
super(
`Each message can have at most one text contentPart, but received: ${numTextContent}`,
);
this.name = "MultipleTextContentError";
}
}
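
Taken together, the new error classes imply roughly how the post-init check gates non-string content on the model type (the actual validation lives in `postInitAndCheckFieldsChatCompletion`, which is not among the files shown here). A sketch of that logic, under that assumption:

```typescript
import { ModelType } from "./config";
import {
  MultipleTextContentError,
  UnsupportedDetailError,
  UserMessageContentErrorForNonVLM,
} from "./error";

// Sketch only: a standalone helper illustrating the checks the errors above suggest.
type ContentPart =
  | { type: "text"; text: string }
  | { type: "image_url"; image_url: { url: string; detail?: string } };

function checkUserContent(
  modelId: string,
  modelType: ModelType,
  content: string | ContentPart[],
): void {
  if (typeof content === "string") {
    return; // plain string content is always allowed
  }
  if (modelType !== ModelType.VLM) {
    // Array content is reserved for vision-language models.
    throw new UserMessageContentErrorForNonVLM(modelId, ModelType[modelType], content);
  }
  let numTextParts = 0;
  for (const part of content) {
    if (part.type === "text") {
      numTextParts += 1;
    } else if (part.image_url.detail !== undefined) {
      throw new UnsupportedDetailError(part.image_url.detail);
    }
  }
  if (numTextParts > 1) {
    throw new MultipleTextContentError(numTextParts);
  }
}
```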
