[API][Engine] Support loading multiple models in a single engine #542

Merged: 5 commits, Aug 13, 2024
1 change: 1 addition & 0 deletions examples/README.md
@@ -25,6 +25,7 @@ Note that all examples below run in-browser and use WebGPU as a backend.
- [multi-round-chat](multi-round-chat): while APIs are functional, we internally optimize so that multi round chat usage can reuse KV cache
- [text-completion](text-completion): demonstrates API `engine.completions.create()`, which is pure text completion with no conversation, as opposed to `engine.chat.completions.create()`
- [embeddings](embeddings): demonstrates API `engine.embeddings.create()`, and integration with `EmbeddingsInterface` and `MemoryVectorStore` of [Langchain.js](js.langchain.com)
- [multi-models](multi-models): demonstrates loading multiple models concurrently in a single engine

#### Advanced OpenAI API Capabilities

14 changes: 14 additions & 0 deletions examples/multi-models/README.md
@@ -0,0 +1,14 @@
# WebLLM Multi-Model Example

This folder provides a minimal demo that loads multiple models into a single WebLLM engine in a webapp setting.
To try it out, run the following commands in this folder:

```bash
npm install
npm start
```

Note: if you would like to hack on the WebLLM core package, change the `web-llm` dependency to `"file:../.."` and follow the project's build-from-source instructions to build WebLLM locally. This is only recommended if you intend to modify the core package.
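The core of what this example exercises is creating one engine with several model IDs and selecting a model per request via the `model` field. A minimal non-streaming sketch (model IDs taken from `src/multi_models.ts`; the full streaming version lives in that file):

```ts
import * as webllm from "@mlc-ai/web-llm";

// Load two models into a single engine.
const engine = await webllm.CreateMLCEngine([
  "Phi-3-mini-4k-instruct-q4f32_1-MLC-1k",
  "gemma-2-2b-it-q4f32_1-MLC-1k",
]);

// Each request must name the model it targets, since more than one is loaded.
const reply = await engine.chat.completions.create({
  messages: [{ role: "user", content: "Provide me three US states." }],
  model: "Phi-3-mini-4k-instruct-q4f32_1-MLC-1k",
});
console.log(reply.choices[0].message.content);
```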
20 changes: 20 additions & 0 deletions examples/multi-models/package.json
@@ -0,0 +1,20 @@
{
"name": "get-started",
"version": "0.1.0",
"private": true,
"scripts": {
"start": "parcel src/multi_models.html --port 8888",
"build": "parcel build src/multi_models.html --dist-dir lib"
},
"devDependencies": {
"buffer": "^5.7.1",
"parcel": "^2.8.3",
"process": "^0.11.10",
"tslib": "^2.3.1",
"typescript": "^4.9.5",
"url": "^0.11.3"
},
"dependencies": {
"@mlc-ai/web-llm": "file:../.."
}
}
23 changes: 23 additions & 0 deletions examples/multi-models/src/multi_models.html
@@ -0,0 +1,23 @@
<!doctype html>
<html>
<script>
webLLMGlobal = {};
</script>
<body>
<h2>WebLLM Test Page</h2>
Open console to see output
<br />
<br />
<label id="init-label"> </label>

<h3>Prompt</h3>
<label id="prompt-label"> </label>

<h3>Response</h3>
<label id="generate-label"> </label>
<br />
<label id="stats-label"> </label>

<script type="module" src="./multi_models.ts"></script>
</body>
</html>
76 changes: 76 additions & 0 deletions examples/multi-models/src/multi_models.ts
@@ -0,0 +1,76 @@
import * as webllm from "@mlc-ai/web-llm";

function setLabel(id: string, text: string) {
const label = document.getElementById(id);
if (label == null) {
throw Error("Cannot find label " + id);
}
label.innerText = text;
}

/**
* Chat completion (OpenAI style) with streaming, using two models loaded in the same engine.
*/
async function mainStreaming() {
const initProgressCallback = (report: webllm.InitProgressReport) => {
setLabel("init-label", report.text);
};
const selectedModel1 = "Phi-3-mini-4k-instruct-q4f32_1-MLC-1k";
const selectedModel2 = "gemma-2-2b-it-q4f32_1-MLC-1k";

const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
[selectedModel1, selectedModel2],
{ initProgressCallback: initProgressCallback },
);

const request1: webllm.ChatCompletionRequest = {
stream: true,
stream_options: { include_usage: true },
messages: [
{ role: "user", content: "Provide me three US states." },
{ role: "assistant", content: "California, New York, Pennsylvania." },
{ role: "user", content: "Two more please!" },
],
model: selectedModel1, // if omitted, an error is thrown due to ambiguity between the loaded models
};

const request2: webllm.ChatCompletionRequest = {
stream: true,
stream_options: { include_usage: true },
messages: [
{ role: "user", content: "Provide me three cities in NY." },
{ role: "assistant", content: "New York, Binghamton, Buffalo." },
{ role: "user", content: "Two more please!" },
],
model: selectedModel2, // if omitted, an error is thrown due to ambiguity between the loaded models
};

const asyncChunkGenerator1 = await engine.chat.completions.create(request1);
let message = "";
for await (const chunk of asyncChunkGenerator1) {
console.log(chunk);
message += chunk.choices[0]?.delta?.content || "";
setLabel("generate-label", message);
if (chunk.usage) {
console.log(chunk.usage); // only last chunk has usage
}
// engine.interruptGenerate(); // works with interrupt as well
}
const asyncChunkGenerator2 = await engine.chat.completions.create(request2);
message += "\n\n";
for await (const chunk of asyncChunkGenerator2) {
console.log(chunk);
message += chunk.choices[0]?.delta?.content || "";
setLabel("generate-label", message);
if (chunk.usage) {
console.log(chunk.usage); // only last chunk has usage
}
// engine.interruptGenerate(); // works with interrupt as well
}

// getMessage() requires specifying the model when multiple are loaded; otherwise an error is thrown due to ambiguity
console.log("Final message 1:\n", await engine.getMessage(selectedModel1));
console.log("Final message 2:\n", await engine.getMessage(selectedModel2));
}

mainStreaming();
4 changes: 4 additions & 0 deletions src/embedding.ts
@@ -265,6 +265,10 @@ export class EmbeddingPipeline {
await this.device.sync();
}

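// Asynchronously load the WebGPU pipelines of the VM's internal module
// (delegates to TVM's asyncLoadWebGPUPipelines).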
async asyncLoadWebGPUPipelines() {
await this.tvm.asyncLoadWebGPUPipelines(this.vm.getInternalModule());
}

// Performance APIs below

/**