[Config] Enhance ModelRecord #435

Merged (3 commits) on May 30, 2024

12 changes: 6 additions & 6 deletions README.md
@@ -52,7 +52,7 @@ async function main() {
const label = document.getElementById("init-label");
label.innerText = report.text;
};
const selectedModel = "Llama-3-8B-Instruct-q4f32_1";
const selectedModel = "Llama-3-8B-Instruct-q4f32_1-MLC";
const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
selectedModel,
/*engineConfig=*/ { initProgressCallback: initProgressCallback },
@@ -96,7 +96,7 @@ async function main() {
const initProgressCallback = (report) => {
console.log(report.text);
};
const selectedModel = "TinyLlama-1.1B-Chat-v0.4-q4f16_1-1k";
const selectedModel = "TinyLlama-1.1B-Chat-v0.4-q4f16_1-MLC-1k";
const engine = await webllm.CreateMLCEngine(selectedModel, {
initProgressCallback: initProgressCallback,
});
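
Both README hunks above also rename the example model ids to variants carrying an "-MLC" suffix. If you need to check which ids a given release of @mlc-ai/web-llm actually bundles, a small sketch (relying on the package's exported `prebuiltAppConfig`, which this diff does not touch) is:

```ts
import { prebuiltAppConfig } from "@mlc-ai/web-llm";

// Print the model ids bundled with the package; post-rename ids end in "-MLC".
console.log(prebuiltAppConfig.model_list.map((m) => m.model_id));
```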
@@ -247,8 +247,8 @@ on how to add new model weights and libraries to WebLLM.

Here, we go over the high-level idea. There are two elements of the WebLLM package that enable new models and weight variants.

- `model_url`: Contains a URL to model artifacts, such as weights and meta-data.
- `model_lib_url`: A URL to the web assembly library (i.e. wasm file) that contains the executables to accelerate the model computations.
- `model`: Contains a URL to model artifacts, such as weights and meta-data.
- `model_lib`: A URL to the web assembly library (i.e. wasm file) that contains the executables to accelerate the model computations.

Both are customizable in WebLLM.

@@ -257,9 +257,9 @@ async main() {
const appConfig = {
"model_list": [
{
"model_url": "/url/to/my/llama",
"model": "/url/to/my/llama",
"model_id": "MyLlama-3b-v1-q4f32_0"
"model_lib_url": "/url/to/myllama3b.wasm",
"model_lib": "/url/to/myllama3b.wasm",
}
],
};
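
The renamed `model` and `model_lib` fields plug into an engine the same way the old `model_url`/`model_lib_url` pair did. A minimal sketch, with placeholder URLs and a placeholder model id (none of these artifacts ship with the PR):

```ts
import * as webllm from "@mlc-ai/web-llm";

// Placeholder artifacts; substitute your own weight repo and compiled wasm library.
const appConfig: webllm.AppConfig = {
  model_list: [
    {
      model: "https://example.com/url/to/my/llama", // weights and metadata
      model_id: "MyLlama-3b-v1-q4f32_0",
      model_lib: "https://example.com/url/to/myllama3b.wasm", // wasm model library
    },
  ],
};

async function main() {
  // Load the custom model through the same CreateMLCEngine entry point used above.
  const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
    "MyLlama-3b-v1-q4f32_0",
    { appConfig: appConfig },
  );
  const reply = await engine.chat.completions.create({
    messages: [{ role: "user", content: "Hello!" }],
  });
  console.log(reply.choices[0].message.content);
}

main();
```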
15 changes: 10 additions & 5 deletions examples/cache-usage/src/cache_usage.ts
@@ -24,16 +24,19 @@ async function main() {
}

// 1. This triggers downloading and caching the model with either Cache or IndexedDB Cache
const selectedModel = "Phi2-q4f16_1"
const selectedModel = "phi-2-q4f16_1-MLC";
const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
"Phi2-q4f16_1",
{ initProgressCallback: initProgressCallback, appConfig: appConfig }
selectedModel,
{ initProgressCallback: initProgressCallback, appConfig: appConfig },
);

const request: webllm.ChatCompletionRequest = {
stream: false,
messages: [
{ "role": "user", "content": "Write an analogy between mathematics and a lighthouse." },
{
role: "user",
content: "Write an analogy between mathematics and a lighthouse.",
},
],
n: 1,
};
@@ -60,7 +63,9 @@ async function main() {
modelCached = await webllm.hasModelInCache(selectedModel, appConfig);
console.log("After deletion, hasModelInCache: ", modelCached);
if (modelCached) {
throw Error("Expect hasModelInCache() to be false, but got: " + modelCached);
throw Error(
"Expect hasModelInCache() to be false, but got: " + modelCached,
);
}

// 5. If we reload, we should expect the model to start downloading again
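
The cache-usage changes keep the example's flow intact: load, check the cache, delete, reload. A condensed sketch of the check step, restricted to calls that appear in this example and assuming, as the example does, that `AppConfig` accepts a `useIndexedDBCache` flag and that the package exports `prebuiltAppConfig`:

```ts
import * as webllm from "@mlc-ai/web-llm";

async function checkCache() {
  // Assumed flag: switches the example between the Cache API and IndexedDB.
  const appConfig: webllm.AppConfig = {
    ...webllm.prebuiltAppConfig,
    useIndexedDBCache: true,
  };
  const selectedModel = "phi-2-q4f16_1-MLC";

  console.log(
    "Cached before load:",
    await webllm.hasModelInCache(selectedModel, appConfig),
  );

  // Creating the engine downloads the weights and populates the chosen cache.
  await webllm.CreateMLCEngine(selectedModel, { appConfig });

  console.log(
    "Cached after load:",
    await webllm.hasModelInCache(selectedModel, appConfig),
  );
}

checkCache();
```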
6 changes: 3 additions & 3 deletions examples/chrome-extension-webgpu-service-worker/src/popup.ts
@@ -47,8 +47,8 @@ const initProgressCallback = (report: InitProgressReport) => {
};

const engine: MLCEngineInterface = await CreateExtensionServiceWorkerMLCEngine(
"Mistral-7B-Instruct-v0.2-q4f16_1",
{ initProgressCallback: initProgressCallback }
"Mistral-7B-Instruct-v0.2-q4f16_1-MLC",
{ initProgressCallback: initProgressCallback },
);
const chatHistory: ChatCompletionMessageParam[] = [];

@@ -150,7 +150,7 @@ function updateAnswer(answer: string) {
function fetchPageContents() {
chrome.tabs.query({ currentWindow: true, active: true }, function (tabs) {
if (tabs[0]?.id) {
var port = chrome.tabs.connect(tabs[0].id, { name: "channelName" });
const port = chrome.tabs.connect(tabs[0].id, { name: "channelName" });
port.postMessage({});
port.onMessage.addListener(function (msg) {
console.log("Page contents:", msg.contents);
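
Once created, the service-worker engine is driven through the same OpenAI-style chat API used by the regular chrome-extension popup below. A condensed sketch of that flow, restricted to calls that appear elsewhere in this diff (the prompt text is illustrative):

```ts
import {
  CreateExtensionServiceWorkerMLCEngine,
  MLCEngineInterface,
  ChatCompletionMessageParam,
} from "@mlc-ai/web-llm";

async function ask(question: string): Promise<string> {
  // The engine proxies requests to the WebGPU-backed extension service worker.
  const engine: MLCEngineInterface =
    await CreateExtensionServiceWorkerMLCEngine(
      "Mistral-7B-Instruct-v0.2-q4f16_1-MLC",
      { initProgressCallback: (report) => console.log(report.text) },
    );

  const chatHistory: ChatCompletionMessageParam[] = [
    { role: "user", content: question },
  ];

  // Stream deltas as they arrive and accumulate the final answer.
  let answer = "";
  const completion = await engine.chat.completions.create({
    stream: true,
    messages: chatHistory,
  });
  for await (const chunk of completion) {
    answer += chunk.choices[0].delta.content ?? "";
  }
  return answer;
}
```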
217 changes: 118 additions & 99 deletions examples/chrome-extension/src/popup.ts
@@ -1,12 +1,17 @@
/* eslint-disable @typescript-eslint/no-non-null-assertion */
'use strict';
"use strict";

// This code is partially adapted from the openai-chatgpt-chrome-extension repo:
// https://github.com/jessedi0n/openai-chatgpt-chrome-extension

import './popup.css';
import "./popup.css";

import { MLCEngineInterface, InitProgressReport, CreateMLCEngine, ChatCompletionMessageParam } from "@mlc-ai/web-llm";
import {
MLCEngineInterface,
InitProgressReport,
CreateMLCEngine,
ChatCompletionMessageParam,
} from "@mlc-ai/web-llm";
import { ProgressBar, Line } from "progressbar.js";

const sleep = (ms: number) => new Promise((r) => setTimeout(r, ms));
@@ -21,135 +26,149 @@ fetchPageContents();

(<HTMLButtonElement>submitButton).disabled = true;

const progressBar: ProgressBar = new Line('#loadingContainer', {
strokeWidth: 4,
easing: 'easeInOut',
duration: 1400,
color: '#ffd166',
trailColor: '#eee',
trailWidth: 1,
svgStyle: { width: '100%', height: '100%' }
const progressBar: ProgressBar = new Line("#loadingContainer", {
strokeWidth: 4,
easing: "easeInOut",
duration: 1400,
color: "#ffd166",
trailColor: "#eee",
trailWidth: 1,
svgStyle: { width: "100%", height: "100%" },
});

const initProgressCallback = (report: InitProgressReport) => {
console.log(report.text, report.progress);
progressBar.animate(report.progress, {
duration: 50
});
if (report.progress == 1.0) {
enableInputs();
}
console.log(report.text, report.progress);
progressBar.animate(report.progress, {
duration: 50,
});
if (report.progress == 1.0) {
enableInputs();
}
};

// const selectedModel = "TinyLlama-1.1B-Chat-v0.4-q4f16_1-1k";
const selectedModel = "Mistral-7B-Instruct-v0.2-q4f16_1";
const engine: MLCEngineInterface = await CreateMLCEngine(
selectedModel,
{ initProgressCallback: initProgressCallback }
);
// const selectedModel = "TinyLlama-1.1B-Chat-v0.4-q4f16_1-MLC-1k";
const selectedModel = "Mistral-7B-Instruct-v0.2-q4f16_1-MLC";
const engine: MLCEngineInterface = await CreateMLCEngine(selectedModel, {
initProgressCallback: initProgressCallback,
});
const chatHistory: ChatCompletionMessageParam[] = [];

isLoadingParams = true;

function enableInputs() {
if (isLoadingParams) {
sleep(500);
(<HTMLButtonElement>submitButton).disabled = false;
const loadingBarContainer = document.getElementById("loadingContainer")!;
loadingBarContainer.remove();
queryInput.focus();
isLoadingParams = false;
}
if (isLoadingParams) {
sleep(500);
(<HTMLButtonElement>submitButton).disabled = false;
const loadingBarContainer = document.getElementById("loadingContainer")!;
loadingBarContainer.remove();
queryInput.focus();
isLoadingParams = false;
}
}

// Disable submit button if input field is empty
queryInput.addEventListener("keyup", () => {
if ((<HTMLInputElement>queryInput).value === "") {
(<HTMLButtonElement>submitButton).disabled = true;
} else {
(<HTMLButtonElement>submitButton).disabled = false;
}
if ((<HTMLInputElement>queryInput).value === "") {
(<HTMLButtonElement>submitButton).disabled = true;
} else {
(<HTMLButtonElement>submitButton).disabled = false;
}
});

// If user presses enter, click submit button
queryInput.addEventListener("keyup", (event) => {
if (event.code === "Enter") {
event.preventDefault();
submitButton.click();
}
if (event.code === "Enter") {
event.preventDefault();
submitButton.click();
}
});

// Listen for clicks on submit button
async function handleClick() {
// Get the message from the input field
const message = (<HTMLInputElement>queryInput).value;
console.log("message", message);
// Clear the answer
document.getElementById("answer")!.innerHTML = "";
// Hide the answer
document.getElementById("answerWrapper")!.style.display = "none";
// Show the loading indicator
document.getElementById("loading-indicator")!.style.display = "block";

// Generate response
let inp = message;
if (context.length > 0) {
inp = "Use only the following context when answering the question at the end. Don't use any other knowledge.\n" + context + "\n\nQuestion: " + message + "\n\nHelpful Answer: ";
}
console.log("Input:", inp);
chatHistory.push({ "role": "user", "content": inp });

let curMessage = "";
const completion = await engine.chat.completions.create({ stream: true, messages: chatHistory });
for await (const chunk of completion) {
const curDelta = chunk.choices[0].delta.content;
if (curDelta) {
curMessage += curDelta;
}
updateAnswer(curMessage);
// Get the message from the input field
const message = (<HTMLInputElement>queryInput).value;
console.log("message", message);
// Clear the answer
document.getElementById("answer")!.innerHTML = "";
// Hide the answer
document.getElementById("answerWrapper")!.style.display = "none";
// Show the loading indicator
document.getElementById("loading-indicator")!.style.display = "block";

// Generate response
let inp = message;
if (context.length > 0) {
inp =
"Use only the following context when answering the question at the end. Don't use any other knowledge.\n" +
context +
"\n\nQuestion: " +
message +
"\n\nHelpful Answer: ";
}
console.log("Input:", inp);
chatHistory.push({ role: "user", content: inp });

let curMessage = "";
const completion = await engine.chat.completions.create({
stream: true,
messages: chatHistory,
});
for await (const chunk of completion) {
const curDelta = chunk.choices[0].delta.content;
if (curDelta) {
curMessage += curDelta;
}
const response = await engine.getMessage();
chatHistory.push({ "role": "assistant", "content": await engine.getMessage() });
console.log("response", response);
updateAnswer(curMessage);
}
const response = await engine.getMessage();
chatHistory.push({ role: "assistant", content: await engine.getMessage() });
console.log("response", response);
}
submitButton.addEventListener("click", handleClick);

// Listen for messages from the background script
chrome.runtime.onMessage.addListener(({ answer, error }) => {
if (answer) {
updateAnswer(answer);
}
if (answer) {
updateAnswer(answer);
}
});

function updateAnswer(answer: string) {
// Show answer
document.getElementById("answerWrapper")!.style.display = "block";
const answerWithBreaks = answer.replace(/\n/g, '<br>');
document.getElementById("answer")!.innerHTML = answerWithBreaks;
// Add event listener to copy button
document.getElementById("copyAnswer")!.addEventListener("click", () => {
// Get the answer text
const answerText = answer;
// Copy the answer text to the clipboard
navigator.clipboard.writeText(answerText)
.then(() => console.log("Answer text copied to clipboard"))
.catch((err) => console.error("Could not copy text: ", err));
});
const options: Intl.DateTimeFormatOptions = { month: 'short', day: '2-digit', hour: '2-digit', minute: '2-digit', second: '2-digit' };
const time = new Date().toLocaleString('en-US', options);
// Update timestamp
document.getElementById("timestamp")!.innerText = time;
// Hide loading indicator
document.getElementById("loading-indicator")!.style.display = "none";
// Show answer
document.getElementById("answerWrapper")!.style.display = "block";
const answerWithBreaks = answer.replace(/\n/g, "<br>");
document.getElementById("answer")!.innerHTML = answerWithBreaks;
// Add event listener to copy button
document.getElementById("copyAnswer")!.addEventListener("click", () => {
// Get the answer text
const answerText = answer;
// Copy the answer text to the clipboard
navigator.clipboard
.writeText(answerText)
.then(() => console.log("Answer text copied to clipboard"))
.catch((err) => console.error("Could not copy text: ", err));
});
const options: Intl.DateTimeFormatOptions = {
month: "short",
day: "2-digit",
hour: "2-digit",
minute: "2-digit",
second: "2-digit",
};
const time = new Date().toLocaleString("en-US", options);
// Update timestamp
document.getElementById("timestamp")!.innerText = time;
// Hide loading indicator
document.getElementById("loading-indicator")!.style.display = "none";
}

function fetchPageContents() {
chrome.tabs.query({ currentWindow: true, active: true }, function (tabs) {
var port = chrome.tabs.connect(tabs[0].id, { name: "channelName" });
port.postMessage({});
port.onMessage.addListener(function (msg) {
console.log("Page contents:", msg.contents);
context = msg.contents
});
chrome.tabs.query({ currentWindow: true, active: true }, function (tabs) {
const port = chrome.tabs.connect(tabs[0].id, { name: "channelName" });
port.postMessage({});
port.onMessage.addListener(function (msg) {
console.log("Page contents:", msg.contents);
context = msg.contents;
});
});
}