[Embeddings][OpenAI] Support embeddings via engine.embeddings.create() #538

Merged · 2 commits · Aug 12, 2024
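For a quick overview, the following is a minimal sketch of the API surface this PR adds, distilled from `examples/embeddings/src/embeddings.ts` below; the model ID and option names mirror that example, and the sketch itself is not part of the diff.

```typescript
import * as webllm from "@mlc-ai/web-llm";

// "snowflake-arctic-embed-m-q0f32-MLC-b4" is one of the embedding model IDs this PR
// registers in prebuiltAppConfig; the -b4 suffix means the model library is compiled
// with a max batch size of 4.
const engine = await webllm.CreateMLCEngine("snowflake-arctic-embed-m-q0f32-MLC-b4");

// Snowflake Arctic models expect inputs wrapped in [CLS] ... [SEP], as in the example code.
const reply = await engine.embeddings.create({
  input: ["[CLS] The Data Cloud! [SEP]", "[CLS] Mexico City of Course! [SEP]"],
});

console.log(reply.data[0].embedding); // number[]
console.log(reply.usage); // token usage statistics
```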
1 change: 1 addition & 0 deletions examples/README.md
@@ -24,6 +24,7 @@ Note that all examples below run in-browser and use WebGPU as a backend.
- [next-simple-chat](next-simple-chat): a minimal and complete chatbot app with [Next.js](https://nextjs.org/).
- [multi-round-chat](multi-round-chat): while APIs are functional, we internally optimize so that multi round chat usage can reuse KV cache
- [text-completion](text-completion): demonstrates API `engine.completions.create()`, which is pure text completion with no conversation, as opposed to `engine.chat.completions.create()`
- [embeddings](embeddings): demonstrates the API `engine.embeddings.create()` and its integration with `EmbeddingsInterface` and `MemoryVectorStore` of [Langchain.js](https://js.langchain.com)

#### Advanced OpenAI API Capabilities

14 changes: 14 additions & 0 deletions examples/embeddings/README.md
@@ -0,0 +1,14 @@
# WebLLM Embeddings Example

This folder provides a minimal demo showing the WebLLM embeddings API in a web app setting.
To try it out, run the following steps under this folder:

```bash
npm install
npm start
```

Note: if you would like to hack the WebLLM core package, you can change the web-llm
dependency to `"file:../.."` and follow the project's build-from-source instructions to
build WebLLM locally. This option is only recommended if you intend to modify the WebLLM
core package.
21 changes: 21 additions & 0 deletions examples/embeddings/package.json
@@ -0,0 +1,21 @@
{
"name": "embeddings-example",
"version": "0.1.0",
"private": true,
"scripts": {
"start": "parcel src/embeddings.html --port 8885",
"build": "parcel build src/embeddings.html --dist-dir lib"
},
"devDependencies": {
"buffer": "^5.7.1",
"parcel": "^2.8.3",
"process": "^0.11.10",
"tslib": "^2.3.1",
"typescript": "^4.9.5",
"url": "^0.11.3"
},
"dependencies": {
"@mlc-ai/web-llm": "file:../..",
"langchain": "0.2.15"
}
}
23 changes: 23 additions & 0 deletions examples/embeddings/src/embeddings.html
@@ -0,0 +1,23 @@
<!doctype html>
<html>
<script>
webLLMGlobal = {};
</script>
<body>
<h2>WebLLM Test Page</h2>
Open console to see output
<br />
<br />
<label id="init-label"> </label>

<h3>Prompt</h3>
<label id="prompt-label"> </label>

<h3>Response</h3>
<label id="generate-label"> </label>
<br />
<label id="stats-label"> </label>

<script type="module" src="./embeddings.ts"></script>
</body>
</html>
147 changes: 147 additions & 0 deletions examples/embeddings/src/embeddings.ts
@@ -0,0 +1,147 @@
import * as webllm from "@mlc-ai/web-llm";
import { MemoryVectorStore } from "langchain/vectorstores/memory";
import type { EmbeddingsInterface } from "@langchain/core/embeddings";
import type { Document } from "@langchain/core/documents";

function setLabel(id: string, text: string) {
const label = document.getElementById(id);
if (label == null) {
throw Error("Cannot find label " + id);
}
label.innerText = text;
}

const initProgressCallback = (report: webllm.InitProgressReport) => {
setLabel("init-label", report.text);
};

// For integration with Langchain
class WebLLMEmbeddings implements EmbeddingsInterface {
engine: webllm.MLCEngineInterface;
constructor(engine: webllm.MLCEngineInterface) {
this.engine = engine;
}

async _embed(texts: string[]): Promise<number[][]> {
const reply = await this.engine.embeddings.create({ input: texts });
const result: number[][] = [];
for (let i = 0; i < texts.length; i++) {
result.push(reply.data[i].embedding);
}
return result;
}

async embedQuery(document: string): Promise<number[]> {
return this._embed([document]).then((embeddings) => embeddings[0]);
}

async embedDocuments(documents: string[]): Promise<number[][]> {
return this._embed(documents);
}
}

// Prepare inputs
const documents_og = ["The Data Cloud!", "Mexico City of Course!"];
const queries_og = ["what is snowflake?", "Where can I get the best tacos?"];
const documents: string[] = [];
const queries: string[] = [];
const query_prefix =
"Represent this sentence for searching relevant passages: ";
// Process according to Snowflake model
documents_og.forEach(function (item, index) {
documents[index] = `[CLS] ${item} [SEP]`;
});
queries_og.forEach(function (item, index) {
queries[index] = `[CLS] ${query_prefix}${item} [SEP]`;
});
console.log("Formatted documents: ", documents);
console.log("Formatted queries: ", queries);

// Using webllm's API
async function webllmAPI() {
// b4 means the max batch size is compiled as 4. That is, the model can process 4 inputs in a
// batch. If given more than 4, the model will forward multiple times. The larger the max batch
// size, the more memory it consumes.
// const selectedModel = "snowflake-arctic-embed-m-q0f32-MLC-b32";
const selectedModel = "snowflake-arctic-embed-m-q0f32-MLC-b4";
const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
selectedModel,
{
initProgressCallback: initProgressCallback,
logLevel: "INFO", // specify the log level
},
);

const docReply = await engine.embeddings.create({ input: documents });
console.log(docReply);
console.log(docReply.usage);

const queryReply = await engine.embeddings.create({ input: queries });
console.log(queryReply);
console.log(queryReply.usage);

// Calculate similarity (we use langchain here, but any method works)
const vectorStore = await MemoryVectorStore.fromExistingIndex(
new WebLLMEmbeddings(engine),
);
// See score
for (let i = 0; i < queries_og.length; i++) {
console.log(`Similarity with: ${queries_og[i]}`);
for (let j = 0; j < documents_og.length; j++) {
const similarity = vectorStore.similarity(
queryReply.data[i].embedding,
docReply.data[j].embedding,
);
console.log(`${documents_og[j]}: ${similarity}`);
}
}
}

// Alternatively, integrating with Langchain's API
async function langchainAPI() {
// b4 means the max batch size is compiled as 4. That is, the model can process 4 inputs in a
// batch. If given more than 4, the model will forward multiple times. The larger the max batch
// size, the more memory it consumes.
// const selectedModel = "snowflake-arctic-embed-m-q0f32-MLC-b32";
const selectedModel = "snowflake-arctic-embed-m-q0f32-MLC-b4";
const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
selectedModel,
{
initProgressCallback: initProgressCallback,
logLevel: "INFO", // specify the log level
},
);

const vectorStore = await MemoryVectorStore.fromExistingIndex(
new WebLLMEmbeddings(engine),
);
const document0: Document = {
pageContent: documents[0],
metadata: {},
};
const document1: Document = {
pageContent: documents[1],
metadata: {},
};
await vectorStore.addDocuments([document0, document1]);

const similaritySearchResults0 = await vectorStore.similaritySearch(
queries[0],
1,
);
for (const doc of similaritySearchResults0) {
console.log(`* ${doc.pageContent}`);
}

const similaritySearchResults1 = await vectorStore.similaritySearch(
queries[1],
1,
);
for (const doc of similaritySearchResults1) {
console.log(`* ${doc.pageContent}`);
}
}

// Select one to run
webllmAPI();
// langchainAPI();
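The example above computes similarity through Langchain's `MemoryVectorStore`; as the comment in `webllmAPI()` notes, any method works. Purely as an illustration (not part of this PR), a dependency-free cosine-similarity helper over the returned embeddings could look like this:

```typescript
// Illustrative helper, not part of this PR: cosine similarity between two embedding
// vectors returned by engine.embeddings.create(), without the Langchain dependency.
function cosineSimilarity(a: number[], b: number[]): number {
  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}

// e.g. cosineSimilarity(queryReply.data[0].embedding, docReply.data[0].embedding)
```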
55 changes: 52 additions & 3 deletions src/config.ts
@@ -68,10 +68,10 @@ export interface TokenizerInfo {
* Only these fields affect the conversation in runtime.
* i.e. The third part in https://llm.mlc.ai/docs/get_started/mlc_chat_config.html.
*
* This is initialized in `ChatModule.reload()` with the model's `mlc-chat-config.json`.
* This is initialized in `MLCEngine.reload()` with the model's `mlc-chat-config.json`.
*/
export interface ChatConfig {
// First three fields affect the entire conversation, i.e. used in `ChatModule.reload()`
// First three fields affect the entire conversation, i.e. used in `MLCEngine.reload()`
tokenizer_files: Array<string>;
tokenizer_info?: TokenizerInfo;
token_table_postproc_method?: string; // TODO: backward compatibility, remove soon
@@ -122,7 +122,7 @@ export interface MLCEngineConfig {
* We also support additional fields not present in `mlc-chat-config.json` due to OpenAI-like APIs.
*
* Note that all values are optional. If unspecified, we use whatever values in `ChatConfig`
* initialized during `ChatModule.reload()`.
* initialized during `MLCEngine.reload()`.
*/
export interface GenerationConfig {
// Only used in MLC
@@ -226,6 +226,11 @@ export function postInitAndCheckGenerationConfigValues(
}
}

export enum ModelType {
"LLM",
"embedding",
}

/**
* Information for a model.
* @param model: the huggingface link to download the model weights, accepting four formats:
@@ -241,6 +246,7 @@ export function postInitAndCheckGenerationConfigValues(
* @param low_resource_required: whether the model can run on limited devices (e.g. Android phone).
* @param buffer_size_required_bytes: required `maxStorageBufferBindingSize`, different for each device.
* @param required_features: feature needed to run this model (e.g. shader-f16).
* @param model_type: the intended use case for the model; if unspecified, defaults to LLM.
*/
export interface ModelRecord {
model: string;
@@ -251,6 +257,7 @@ export function postInitAndCheckGenerationConfigValues(
low_resource_required?: boolean;
buffer_size_required_bytes?: number;
required_features?: Array<string>;
model_type?: ModelType;
}

/**
Expand Down Expand Up @@ -1514,5 +1521,47 @@ export const prebuiltAppConfig: AppConfig = {
context_window_size: 1024,
},
},
// Embedding models
// The -b suffix denotes the max_batch_size this model allows. The smaller it is, the less memory the model consumes.
{
model: "https://huggingface.co/mlc-ai/snowflake-arctic-embed-m-q0f32-MLC",
model_id: "snowflake-arctic-embed-m-q0f32-MLC-b32",
model_lib:
modelLibURLPrefix +
modelVersion +
"/snowflake-arctic-embed-m-q0f32-ctx512_cs512_batch32-webgpu.wasm",
vram_required_MB: 1407.51,
model_type: ModelType.embedding,
},
{
model: "https://huggingface.co/mlc-ai/snowflake-arctic-embed-m-q0f32-MLC",
model_id: "snowflake-arctic-embed-m-q0f32-MLC-b4",
model_lib:
modelLibURLPrefix +
modelVersion +
"/snowflake-arctic-embed-m-q0f32-ctx512_cs512_batch4-webgpu.wasm",
vram_required_MB: 539.4,
model_type: ModelType.embedding,
},
{
model: "https://huggingface.co/mlc-ai/snowflake-arctic-embed-s-q0f32-MLC",
model_id: "snowflake-arctic-embed-s-q0f32-MLC-b32",
model_lib:
modelLibURLPrefix +
modelVersion +
"/snowflake-arctic-embed-s-q0f32-ctx512_cs512_batch32-webgpu.wasm",
vram_required_MB: 1022.82,
model_type: ModelType.embedding,
},
{
model: "https://huggingface.co/mlc-ai/snowflake-arctic-embed-s-q0f32-MLC",
model_id: "snowflake-arctic-embed-s-q0f32-MLC-b4",
model_lib:
modelLibURLPrefix +
modelVersion +
"/snowflake-arctic-embed-s-q0f32-ctx512_cs512_batch4-webgpu.wasm",
vram_required_MB: 238.71,
model_type: ModelType.embedding,
},
],
};
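Beyond the prebuilt entries above, the new `model_type` field lets users mark their own model records as embedding models in a custom `AppConfig`. The sketch below is illustrative only: the model and library URLs are placeholders, and it assumes `ModelType`, `AppConfig`, and `prebuiltAppConfig` are exported from the package entry point and that `CreateMLCEngine` accepts an `appConfig` in its engine config.

```typescript
import * as webllm from "@mlc-ai/web-llm";

// Sketch only: the URLs below are placeholders; see the prebuilt entries above for real values.
const appConfig: webllm.AppConfig = {
  model_list: [
    ...webllm.prebuiltAppConfig.model_list,
    {
      model: "https://huggingface.co/my-org/my-embedding-model-MLC", // placeholder
      model_id: "my-embedding-model-MLC-b4",
      model_lib:
        "https://example.com/my-embedding-model-ctx512_cs512_batch4-webgpu.wasm", // placeholder
      model_type: webllm.ModelType.embedding, // assumes ModelType is re-exported by the package
    },
  ],
};

const engine = await webllm.CreateMLCEngine("my-embedding-model-MLC-b4", { appConfig });
const reply = await engine.embeddings.create({ input: ["[CLS] hello [SEP]"] });
console.log(reply.data[0].embedding);
```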