[Config] Enhance ModelRecord (#435)
This PR brings three changes to `ModelRecord`:

### 1. Update model ids to match HF repo name
We rename each `modelId` in `webllm.prebuiltAppConfig` to exactly match the
HF repo name. For most models, this simply means appending `-MLC` to the
`modelId`. For the low-context version of a model, we use `{HF-repo}-1k`,
indicating a 1k context length.

As a result, we also rename the Phi-2 and Phi-1.5 models, since their
`modelId`s did not match the repo names:
- `Phi2-q4f32_1` → `phi-2-q4f32_1-MLC`
- `Phi1.5-q4f16_1` → `phi-1_5-q4f16_1-MLC`
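
The naming scheme above can be sketched as a small helper. This is a hypothetical illustration only, not a function that WebLLM ships:

```typescript
// Hypothetical helper (not part of WebLLM) illustrating the naming scheme:
// the model id is the HF repo name itself, with "-1k" appended for the
// low-context variant of a model.
function toModelId(hfRepoName: string, lowContext = false): string {
  return lowContext ? `${hfRepoName}-1k` : hfRepoName;
}
```

For example, the low-context TinyLlama record would get the id `TinyLlama-1.1B-Chat-v0.4-q4f16_1-MLC-1k`.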

### 2. Rename `model_url` and `model_lib_url` to `model` and `model_lib`
To better match the other MLC-LLM platforms (e.g. iOS, Android), we rename
these `ModelRecord` fields.

### 3. Remove `resolve/main` from `model` URL
Instead of
`"https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC/resolve/main/"`,
the URL is now
`"https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC/"`; note
that we append the trailing `/` automatically if it is missing.
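
The trailing-slash handling can be sketched as follows (assumed behavior for illustration; the actual WebLLM implementation may differ):

```typescript
// Ensure a model URL ends with "/" so artifact paths can be resolved
// against it; a no-op when the slash is already present.
function withTrailingSlash(model: string): string {
  return model.endsWith("/") ? model : model + "/";
}
```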

### Example
As an example, we would have:
```typescript
    {
      model: "https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC",
      model_id: "Llama-3-8B-Instruct-q4f16_1-MLC",
      model_lib: "path/to/Llama-3-8B-Instruct-q4f16_1-ctx1k_cs1k-webgpu.wasm",
    },
```
instead of 
```typescript
    {
      model_url: "https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC/resolve/main/",
      model_id: "Llama-3-8B-Instruct-q4f16_1",
      model_lib_url: "path/to/Llama-3-8B-Instruct-q4f16_1-ctx4k_cs1k-webgpu.wasm",
    },
```
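
For app configs written against the old fields, the rename can be applied mechanically. Below is a hypothetical migration sketch (the interfaces mirror the two examples above); note that `model_id` may additionally need updating to match the HF repo name, which this sketch does not automate:

```typescript
// Old-style record, as used before this PR.
interface OldModelRecord {
  model_url: string;
  model_id: string;
  model_lib_url: string;
}

// New-style record, as introduced by this PR.
interface NewModelRecord {
  model: string;
  model_id: string;
  model_lib: string;
}

// Hypothetical migration helper: renames the fields and strips the
// now-unneeded "resolve/main/" suffix from the model URL.
function migrateRecord(old: OldModelRecord): NewModelRecord {
  return {
    model: old.model_url.replace(/\/?resolve\/main\/?$/, ""),
    model_id: old.model_id,
    model_lib: old.model_lib_url,
  };
}
```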

---------

Co-authored-by: Nestor Qin <[email protected]>
CharlieFRuan and Neet-Nestor authored May 30, 2024
1 parent c995caa commit 896b012
Showing 24 changed files with 788 additions and 718 deletions.
12 changes: 6 additions & 6 deletions README.md
@@ -52,7 +52,7 @@ async function main() {
const label = document.getElementById("init-label");
label.innerText = report.text;
};
const selectedModel = "Llama-3-8B-Instruct-q4f32_1";
const selectedModel = "Llama-3-8B-Instruct-q4f32_1-MLC";
const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
selectedModel,
/*engineConfig=*/ { initProgressCallback: initProgressCallback },
@@ -96,7 +96,7 @@ async function main() {
const initProgressCallback = (report) => {
console.log(report.text);
};
const selectedModel = "TinyLlama-1.1B-Chat-v0.4-q4f16_1-1k";
const selectedModel = "TinyLlama-1.1B-Chat-v0.4-q4f16_1-MLC-1k";
const engine = await webllm.CreateMLCEngine(selectedModel, {
initProgressCallback: initProgressCallback,
});
@@ -247,8 +247,8 @@ on how to add new model weights and libraries to WebLLM.

Here, we go over the high-level idea. There are two elements of the WebLLM package that enables new models and weight variants.

- `model_url`: Contains a URL to model artifacts, such as weights and meta-data.
- `model_lib_url`: A URL to the web assembly library (i.e. wasm file) that contains the executables to accelerate the model computations.
- `model`: Contains a URL to model artifacts, such as weights and meta-data.
- `model_lib`: A URL to the web assembly library (i.e. wasm file) that contains the executables to accelerate the model computations.

Both are customizable in the WebLLM.

@@ -257,9 +257,9 @@ async main() {
const appConfig = {
"model_list": [
{
"model_url": "/url/to/my/llama",
"model": "/url/to/my/llama",
"model_id": "MyLlama-3b-v1-q4f32_0"
"model_lib_url": "/url/to/myllama3b.wasm",
"model_lib": "/url/to/myllama3b.wasm",
}
],
};
15 changes: 10 additions & 5 deletions examples/cache-usage/src/cache_usage.ts
@@ -24,16 +24,19 @@ async function main() {
}

// 1. This triggers downloading and caching the model with either Cache or IndexedDB Cache
const selectedModel = "Phi2-q4f16_1"
const selectedModel = "phi-2-q4f16_1-MLC";
const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
"Phi2-q4f16_1",
{ initProgressCallback: initProgressCallback, appConfig: appConfig }
selectedModel,
{ initProgressCallback: initProgressCallback, appConfig: appConfig },
);

const request: webllm.ChatCompletionRequest = {
stream: false,
messages: [
{ "role": "user", "content": "Write an analogy between mathematics and a lighthouse." },
{
role: "user",
content: "Write an analogy between mathematics and a lighthouse.",
},
],
n: 1,
};
@@ -60,7 +63,9 @@ async function main() {
modelCached = await webllm.hasModelInCache(selectedModel, appConfig);
console.log("After deletion, hasModelInCache: ", modelCached);
if (modelCached) {
throw Error("Expect hasModelInCache() to be false, but got: " + modelCached);
throw Error(
"Expect hasModelInCache() to be false, but got: " + modelCached,
);
}

// 5. If we reload, we should expect the model to start downloading again
6 changes: 3 additions & 3 deletions examples/chrome-extension-webgpu-service-worker/src/popup.ts
@@ -47,8 +47,8 @@ const initProgressCallback = (report: InitProgressReport) => {
};

const engine: MLCEngineInterface = await CreateExtensionServiceWorkerMLCEngine(
"Mistral-7B-Instruct-v0.2-q4f16_1",
{ initProgressCallback: initProgressCallback }
"Mistral-7B-Instruct-v0.2-q4f16_1-MLC",
{ initProgressCallback: initProgressCallback },
);
const chatHistory: ChatCompletionMessageParam[] = [];

@@ -150,7 +150,7 @@ function updateAnswer(answer: string) {
function fetchPageContents() {
chrome.tabs.query({ currentWindow: true, active: true }, function (tabs) {
if (tabs[0]?.id) {
var port = chrome.tabs.connect(tabs[0].id, { name: "channelName" });
const port = chrome.tabs.connect(tabs[0].id, { name: "channelName" });
port.postMessage({});
port.onMessage.addListener(function (msg) {
console.log("Page contents:", msg.contents);
217 changes: 118 additions & 99 deletions examples/chrome-extension/src/popup.ts
@@ -1,12 +1,17 @@
/* eslint-disable @typescript-eslint/no-non-null-assertion */
'use strict';
"use strict";

// This code is partially adapted from the openai-chatgpt-chrome-extension repo:
// https://github.com/jessedi0n/openai-chatgpt-chrome-extension

import './popup.css';
import "./popup.css";

import { MLCEngineInterface, InitProgressReport, CreateMLCEngine, ChatCompletionMessageParam } from "@mlc-ai/web-llm";
import {
MLCEngineInterface,
InitProgressReport,
CreateMLCEngine,
ChatCompletionMessageParam,
} from "@mlc-ai/web-llm";
import { ProgressBar, Line } from "progressbar.js";

const sleep = (ms: number) => new Promise((r) => setTimeout(r, ms));
@@ -21,135 +26,149 @@ fetchPageContents();

(<HTMLButtonElement>submitButton).disabled = true;

const progressBar: ProgressBar = new Line('#loadingContainer', {
strokeWidth: 4,
easing: 'easeInOut',
duration: 1400,
color: '#ffd166',
trailColor: '#eee',
trailWidth: 1,
svgStyle: { width: '100%', height: '100%' }
const progressBar: ProgressBar = new Line("#loadingContainer", {
strokeWidth: 4,
easing: "easeInOut",
duration: 1400,
color: "#ffd166",
trailColor: "#eee",
trailWidth: 1,
svgStyle: { width: "100%", height: "100%" },
});

const initProgressCallback = (report: InitProgressReport) => {
console.log(report.text, report.progress);
progressBar.animate(report.progress, {
duration: 50
});
if (report.progress == 1.0) {
enableInputs();
}
console.log(report.text, report.progress);
progressBar.animate(report.progress, {
duration: 50,
});
if (report.progress == 1.0) {
enableInputs();
}
};

// const selectedModel = "TinyLlama-1.1B-Chat-v0.4-q4f16_1-1k";
const selectedModel = "Mistral-7B-Instruct-v0.2-q4f16_1";
const engine: MLCEngineInterface = await CreateMLCEngine(
selectedModel,
{ initProgressCallback: initProgressCallback }
);
// const selectedModel = "TinyLlama-1.1B-Chat-v0.4-q4f16_1-MLC-1k";
const selectedModel = "Mistral-7B-Instruct-v0.2-q4f16_1-MLC";
const engine: MLCEngineInterface = await CreateMLCEngine(selectedModel, {
initProgressCallback: initProgressCallback,
});
const chatHistory: ChatCompletionMessageParam[] = [];

isLoadingParams = true;

function enableInputs() {
if (isLoadingParams) {
sleep(500);
(<HTMLButtonElement>submitButton).disabled = false;
const loadingBarContainer = document.getElementById("loadingContainer")!;
loadingBarContainer.remove();
queryInput.focus();
isLoadingParams = false;
}
if (isLoadingParams) {
sleep(500);
(<HTMLButtonElement>submitButton).disabled = false;
const loadingBarContainer = document.getElementById("loadingContainer")!;
loadingBarContainer.remove();
queryInput.focus();
isLoadingParams = false;
}
}

// Disable submit button if input field is empty
queryInput.addEventListener("keyup", () => {
if ((<HTMLInputElement>queryInput).value === "") {
(<HTMLButtonElement>submitButton).disabled = true;
} else {
(<HTMLButtonElement>submitButton).disabled = false;
}
if ((<HTMLInputElement>queryInput).value === "") {
(<HTMLButtonElement>submitButton).disabled = true;
} else {
(<HTMLButtonElement>submitButton).disabled = false;
}
});

// If user presses enter, click submit button
queryInput.addEventListener("keyup", (event) => {
if (event.code === "Enter") {
event.preventDefault();
submitButton.click();
}
if (event.code === "Enter") {
event.preventDefault();
submitButton.click();
}
});

// Listen for clicks on submit button
async function handleClick() {
// Get the message from the input field
const message = (<HTMLInputElement>queryInput).value;
console.log("message", message);
// Clear the answer
document.getElementById("answer")!.innerHTML = "";
// Hide the answer
document.getElementById("answerWrapper")!.style.display = "none";
// Show the loading indicator
document.getElementById("loading-indicator")!.style.display = "block";

// Generate response
let inp = message;
if (context.length > 0) {
inp = "Use only the following context when answering the question at the end. Don't use any other knowledge.\n" + context + "\n\nQuestion: " + message + "\n\nHelpful Answer: ";
}
console.log("Input:", inp);
chatHistory.push({ "role": "user", "content": inp });

let curMessage = "";
const completion = await engine.chat.completions.create({ stream: true, messages: chatHistory });
for await (const chunk of completion) {
const curDelta = chunk.choices[0].delta.content;
if (curDelta) {
curMessage += curDelta;
}
updateAnswer(curMessage);
// Get the message from the input field
const message = (<HTMLInputElement>queryInput).value;
console.log("message", message);
// Clear the answer
document.getElementById("answer")!.innerHTML = "";
// Hide the answer
document.getElementById("answerWrapper")!.style.display = "none";
// Show the loading indicator
document.getElementById("loading-indicator")!.style.display = "block";

// Generate response
let inp = message;
if (context.length > 0) {
inp =
"Use only the following context when answering the question at the end. Don't use any other knowledge.\n" +
context +
"\n\nQuestion: " +
message +
"\n\nHelpful Answer: ";
}
console.log("Input:", inp);
chatHistory.push({ role: "user", content: inp });

let curMessage = "";
const completion = await engine.chat.completions.create({
stream: true,
messages: chatHistory,
});
for await (const chunk of completion) {
const curDelta = chunk.choices[0].delta.content;
if (curDelta) {
curMessage += curDelta;
}
const response = await engine.getMessage();
chatHistory.push({ "role": "assistant", "content": await engine.getMessage() });
console.log("response", response);
updateAnswer(curMessage);
}
const response = await engine.getMessage();
chatHistory.push({ role: "assistant", content: await engine.getMessage() });
console.log("response", response);
}
submitButton.addEventListener("click", handleClick);

// Listen for messages from the background script
chrome.runtime.onMessage.addListener(({ answer, error }) => {
if (answer) {
updateAnswer(answer);
}
if (answer) {
updateAnswer(answer);
}
});

function updateAnswer(answer: string) {
// Show answer
document.getElementById("answerWrapper")!.style.display = "block";
const answerWithBreaks = answer.replace(/\n/g, '<br>');
document.getElementById("answer")!.innerHTML = answerWithBreaks;
// Add event listener to copy button
document.getElementById("copyAnswer")!.addEventListener("click", () => {
// Get the answer text
const answerText = answer;
// Copy the answer text to the clipboard
navigator.clipboard.writeText(answerText)
.then(() => console.log("Answer text copied to clipboard"))
.catch((err) => console.error("Could not copy text: ", err));
});
const options: Intl.DateTimeFormatOptions = { month: 'short', day: '2-digit', hour: '2-digit', minute: '2-digit', second: '2-digit' };
const time = new Date().toLocaleString('en-US', options);
// Update timestamp
document.getElementById("timestamp")!.innerText = time;
// Hide loading indicator
document.getElementById("loading-indicator")!.style.display = "none";
// Show answer
document.getElementById("answerWrapper")!.style.display = "block";
const answerWithBreaks = answer.replace(/\n/g, "<br>");
document.getElementById("answer")!.innerHTML = answerWithBreaks;
// Add event listener to copy button
document.getElementById("copyAnswer")!.addEventListener("click", () => {
// Get the answer text
const answerText = answer;
// Copy the answer text to the clipboard
navigator.clipboard
.writeText(answerText)
.then(() => console.log("Answer text copied to clipboard"))
.catch((err) => console.error("Could not copy text: ", err));
});
const options: Intl.DateTimeFormatOptions = {
month: "short",
day: "2-digit",
hour: "2-digit",
minute: "2-digit",
second: "2-digit",
};
const time = new Date().toLocaleString("en-US", options);
// Update timestamp
document.getElementById("timestamp")!.innerText = time;
// Hide loading indicator
document.getElementById("loading-indicator")!.style.display = "none";
}

function fetchPageContents() {
chrome.tabs.query({ currentWindow: true, active: true }, function (tabs) {
var port = chrome.tabs.connect(tabs[0].id, { name: "channelName" });
port.postMessage({});
port.onMessage.addListener(function (msg) {
console.log("Page contents:", msg.contents);
context = msg.contents
});
chrome.tabs.query({ currentWindow: true, active: true }, function (tabs) {
const port = chrome.tabs.connect(tabs[0].id, { name: "channelName" });
port.postMessage({});
port.onMessage.addListener(function (msg) {
console.log("Page contents:", msg.contents);
context = msg.contents;
});
});
}