[API][Engine] Support loading multiple models in a single engine (#542)
This PR does not change the behavior of any existing API or workflow; it only introduces new behavior when multiple models are loaded in a single engine.

```typescript
  const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
    [selectedModel1, selectedModel2],
  );
```
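With multiple models loaded, each request must specify which model it targets; otherwise the engine throws an error due to ambiguity. A minimal sketch of the per-request selection, reusing the model IDs passed to `CreateMLCEngine` above (non-streaming variant shown for brevity):

```typescript
// Sketch based on the examples/multi-models demo added in this PR.
// `selectedModel1` is one of the model IDs loaded into the engine above;
// the `model` field is required when more than one model is loaded.
const reply = await engine.chat.completions.create({
  messages: [{ role: "user", content: "Provide me three US states." }],
  model: selectedModel1,
});
console.log(reply.choices[0].message.content);
```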
CharlieFRuan committed Aug 13, 2024
1 parent 4e018b9 commit 78fed78
Showing 22 changed files with 987 additions and 323 deletions.
1 change: 1 addition & 0 deletions examples/README.md
@@ -25,6 +25,7 @@ Note that all examples below run in-browser and use WebGPU as a backend.
- [multi-round-chat](multi-round-chat): while APIs are functional, we internally optimize so that multi round chat usage can reuse KV cache
- [text-completion](text-completion): demonstrates API `engine.completions.create()`, which is pure text completion with no conversation, as opposed to `engine.chat.completions.create()`
- [embeddings](embeddings): demonstrates API `engine.embeddings.create()`, and integration with `EmbeddingsInterface` and `MemoryVectorStore` of [Langchain.js](js.langchain.com)
- [multi-models](multi-models): demonstrates loading multiple models in a single engine concurrently

#### Advanced OpenAI API Capabilities

14 changes: 14 additions & 0 deletions examples/multi-models/README.md
@@ -0,0 +1,14 @@
# WebLLM Multi-Model Example

This folder provides a minimal demo that shows loading multiple models with the WebLLM API in a web app setting.
To try it out, run the following commands under this folder:

```bash
npm install
npm start
```

If you would like to hack on the WebLLM core package, you can change the web-llm dependency to `"file:../.."`
and follow the build-from-source instructions in the project to build WebLLM locally.
20 changes: 20 additions & 0 deletions examples/multi-models/package.json
@@ -0,0 +1,20 @@
{
"name": "get-started",
"version": "0.1.0",
"private": true,
"scripts": {
"start": "parcel src/multi_models.html --port 8888",
"build": "parcel build src/multi_models.html --dist-dir lib"
},
"devDependencies": {
"buffer": "^5.7.1",
"parcel": "^2.8.3",
"process": "^0.11.10",
"tslib": "^2.3.1",
"typescript": "^4.9.5",
"url": "^0.11.3"
},
"dependencies": {
"@mlc-ai/web-llm": "file:../.."
}
}
23 changes: 23 additions & 0 deletions examples/multi-models/src/multi_models.html
@@ -0,0 +1,23 @@
<!doctype html>
<html>
<script>
webLLMGlobal = {};
</script>
<body>
<h2>WebLLM Test Page</h2>
Open console to see output
<br />
<br />
<label id="init-label"> </label>

<h3>Prompt</h3>
<label id="prompt-label"> </label>

<h3>Response</h3>
<label id="generate-label"> </label>
<br />
<label id="stats-label"> </label>

<script type="module" src="./multi_models.ts"></script>
</body>
</html>
76 changes: 76 additions & 0 deletions examples/multi-models/src/multi_models.ts
@@ -0,0 +1,76 @@
import * as webllm from "@mlc-ai/web-llm";

function setLabel(id: string, text: string) {
const label = document.getElementById(id);
if (label == null) {
throw Error("Cannot find label " + id);
}
label.innerText = text;
}

/**
* Chat completion (OpenAI style) with streaming, with two models in the pipeline.
*/
async function mainStreaming() {
const initProgressCallback = (report: webllm.InitProgressReport) => {
setLabel("init-label", report.text);
};
const selectedModel1 = "Phi-3-mini-4k-instruct-q4f32_1-MLC-1k";
const selectedModel2 = "gemma-2-2b-it-q4f32_1-MLC-1k";

const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
[selectedModel1, selectedModel2],
{ initProgressCallback: initProgressCallback },
);

const request1: webllm.ChatCompletionRequest = {
stream: true,
stream_options: { include_usage: true },
messages: [
{ role: "user", content: "Provide me three US states." },
{ role: "assistant", content: "California, New York, Pennsylvania." },
{ role: "user", content: "Two more please!" },
],
model: selectedModel1, // without specifying it, an error is thrown due to ambiguity
};

const request2: webllm.ChatCompletionRequest = {
stream: true,
stream_options: { include_usage: true },
messages: [
{ role: "user", content: "Provide me three cities in NY." },
{ role: "assistant", content: "New York, Binghamton, Buffalo." },
{ role: "user", content: "Two more please!" },
],
model: selectedModel2, // without specifying it, an error is thrown due to ambiguity
};

const asyncChunkGenerator1 = await engine.chat.completions.create(request1);
let message = "";
for await (const chunk of asyncChunkGenerator1) {
console.log(chunk);
message += chunk.choices[0]?.delta?.content || "";
setLabel("generate-label", message);
if (chunk.usage) {
console.log(chunk.usage); // only last chunk has usage
}
// engine.interruptGenerate(); // works with interrupt as well
}
const asyncChunkGenerator2 = await engine.chat.completions.create(request2);
message += "\n\n";
for await (const chunk of asyncChunkGenerator2) {
console.log(chunk);
message += chunk.choices[0]?.delta?.content || "";
setLabel("generate-label", message);
if (chunk.usage) {
console.log(chunk.usage); // only last chunk has usage
}
// engine.interruptGenerate(); // works with interrupt as well
}

// without specifying which model to get the message from, an error is thrown due to ambiguity
console.log("Final message 1:\n", await engine.getMessage(selectedModel1));
console.log("Final message 2:\n", await engine.getMessage(selectedModel2));
}

mainStreaming();
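The two streaming loops above are identical apart from the request they drain; as an illustration (not part of this PR), they could be consolidated into a small helper inside the same file:

```typescript
// Illustration only, not part of this PR: drains one streaming chunk
// generator into the label, mirroring the loops in mainStreaming().
async function streamToLabel(
  chunks: AsyncIterable<webllm.ChatCompletionChunk>,
  prefix = "",
): Promise<string> {
  let message = prefix;
  for await (const chunk of chunks) {
    message += chunk.choices[0]?.delta?.content || "";
    setLabel("generate-label", message);
    if (chunk.usage) {
      console.log(chunk.usage); // only the last chunk carries usage
    }
  }
  return message;
}

// Usage, mirroring the two requests above:
//   let message = await streamToLabel(await engine.chat.completions.create(request1));
//   message = await streamToLabel(await engine.chat.completions.create(request2), message + "\n\n");
```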
4 changes: 4 additions & 0 deletions src/embedding.ts
@@ -265,6 +265,10 @@ export class EmbeddingPipeline {
await this.device.sync();
}

async asyncLoadWebGPUPipelines() {
await this.tvm.asyncLoadWebGPUPipelines(this.vm.getInternalModule());
}

// Performance APIs below

/**
