[Embeddings][OpenAI] Support embeddings via engine.embeddings.create()
CharlieFRuan committed Aug 12, 2024
1 parent 66dd646 commit ef27ebc
Showing 17 changed files with 1,112 additions and 31 deletions.
1 change: 1 addition & 0 deletions examples/README.md
@@ -24,6 +24,7 @@ Note that all examples below run in-browser and use WebGPU as a backend.
- [next-simple-chat](next-simple-chat): a minimal and complete chatbot app with [Next.js](https://nextjs.org/).
- [multi-round-chat](multi-round-chat): while the APIs stay the same, we internally optimize so that multi-round chat can reuse the KV cache
- [text-completion](text-completion): demonstrates API `engine.completions.create()`, which is pure text completion with no conversation, as opposed to `engine.chat.completions.create()`
- [embeddings](embeddings): demonstrates API `engine.embeddings.create()`, and integration with `EmbeddingsInterface` and `MemoryVectorStore` of LangChain.js; see the sketch after this list
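
A minimal sketch of the new embeddings API, assuming one of the prebuilt Snowflake models added in this commit (the full example is in the diff below):

```ts
import * as webllm from "@mlc-ai/web-llm";

// Load a prebuilt embedding model; "-b4" is its compiled max batch size.
const engine = await webllm.CreateMLCEngine(
  "snowflake-arctic-embed-m-q0f32-MLC-b4",
);

// OpenAI-style call: one embedding vector per input string.
const reply = await engine.embeddings.create({
  input: ["[CLS] The Data Cloud! [SEP]"],
});
console.log(reply.data[0].embedding); // number[]
console.log(reply.usage);
```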

#### Advanced OpenAI API Capabilities

14 changes: 14 additions & 0 deletions examples/embeddings/README.md
@@ -0,0 +1,14 @@
# WebLLM Embeddings Example

This folder provides a minimal demo showing the WebLLM embeddings API in a web-app setting.
To try it out, run the following commands in this folder:

```bash
npm install
npm start
```

If you would like to hack on the WebLLM core package, change the `web-llm` dependency to
`"file:../.."` and follow the build-from-source instructions in the project to build
WebLLM locally.
21 changes: 21 additions & 0 deletions examples/embeddings/package.json
@@ -0,0 +1,21 @@
{
  "name": "embeddings-example",
  "version": "0.1.0",
  "private": true,
  "scripts": {
    "start": "parcel src/embeddings.html --port 8885",
    "build": "parcel build src/embeddings.html --dist-dir lib"
  },
  "devDependencies": {
    "buffer": "^5.7.1",
    "parcel": "^2.8.3",
    "process": "^0.11.10",
    "tslib": "^2.3.1",
    "typescript": "^4.9.5",
    "url": "^0.11.3"
  },
  "dependencies": {
    "@mlc-ai/web-llm": "file:../..",
    "langchain": "0.2.15"
  }
}
23 changes: 23 additions & 0 deletions examples/embeddings/src/embeddings.html
@@ -0,0 +1,23 @@
<!doctype html>
<html>
  <script>
    webLLMGlobal = {};
  </script>
  <body>
    <h2>WebLLM Test Page</h2>
    Open console to see output
    <br />
    <br />
    <label id="init-label"> </label>

    <h3>Prompt</h3>
    <label id="prompt-label"> </label>

    <h3>Response</h3>
    <label id="generate-label"> </label>
    <br />
    <label id="stats-label"> </label>

    <script type="module" src="./embeddings.ts"></script>
  </body>
</html>
147 changes: 147 additions & 0 deletions examples/embeddings/src/embeddings.ts
@@ -0,0 +1,147 @@
import * as webllm from "@mlc-ai/web-llm";
import { MemoryVectorStore } from "langchain/vectorstores/memory";
import type { EmbeddingsInterface } from "@langchain/core/embeddings";
import type { Document } from "@langchain/core/documents";

function setLabel(id: string, text: string) {
  const label = document.getElementById(id);
  if (label == null) {
    throw Error("Cannot find label " + id);
  }
  label.innerText = text;
}

const initProgressCallback = (report: webllm.InitProgressReport) => {
  setLabel("init-label", report.text);
};

// For integration with Langchain
class WebLLMEmbeddings implements EmbeddingsInterface {
  engine: webllm.MLCEngineInterface;
  constructor(engine: webllm.MLCEngineInterface) {
    this.engine = engine;
  }

  async _embed(texts: string[]): Promise<number[][]> {
    const reply = await this.engine.embeddings.create({ input: texts });
    const result: number[][] = [];
    for (let i = 0; i < texts.length; i++) {
      result.push(reply.data[i].embedding);
    }
    return result;
  }

  async embedQuery(document: string): Promise<number[]> {
    return this._embed([document]).then((embeddings) => embeddings[0]);
  }

  async embedDocuments(documents: string[]): Promise<number[][]> {
    return this._embed(documents);
  }
}

// Prepare inputs
const documents_og = ["The Data Cloud!", "Mexico City of Course!"];
const queries_og = ["what is snowflake?", "Where can I get the best tacos?"];
const documents: string[] = [];
const queries: string[] = [];
const query_prefix =
  "Represent this sentence for searching relevant passages: ";
// Process according to Snowflake model
documents_og.forEach(function (item, index) {
  documents[index] = `[CLS] ${item} [SEP]`;
});
queries_og.forEach(function (item, index) {
  queries[index] = `[CLS] ${query_prefix}${item} [SEP]`;
});
console.log("Formatted documents: ", documents);
console.log("Formatted queries: ", queries);

// Using webllm's API
async function webllmAPI() {
  // b4 means the max batch size is compiled as 4. That is, the model can process 4 inputs in a
  // batch. If given more than 4, the model will forward multiple times. The larger the max batch
  // size, the more memory it consumes.
  // const selectedModel = "snowflake-arctic-embed-m-q0f32-MLC-b32";
  const selectedModel = "snowflake-arctic-embed-m-q0f32-MLC-b4";
  const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
    selectedModel,
    {
      initProgressCallback: initProgressCallback,
      logLevel: "INFO", // specify the log level
    },
  );

  const docReply = await engine.embeddings.create({ input: documents });
  console.log(docReply);
  console.log(docReply.usage);

  const queryReply = await engine.embeddings.create({ input: queries });
  console.log(queryReply);
  console.log(queryReply.usage);

  // Calculate similarity (we use langchain here, but any method works)
  const vectorStore = await MemoryVectorStore.fromExistingIndex(
    new WebLLMEmbeddings(engine),
  );
  // See score
  for (let i = 0; i < queries_og.length; i++) {
    console.log(`Similarity with: ${queries_og[i]}`);
    for (let j = 0; j < documents_og.length; j++) {
      const similarity = vectorStore.similarity(
        queryReply.data[i].embedding,
        docReply.data[j].embedding,
      );
      console.log(`${documents_og[j]}: ${similarity}`);
    }
  }
}

// Alternatively, integrating with Langchain's API
async function langchainAPI() {
  // b4 means the max batch size is compiled as 4. That is, the model can process 4 inputs in a
  // batch. If given more than 4, the model will forward multiple times. The larger the max batch
  // size, the more memory it consumes.
  // const selectedModel = "snowflake-arctic-embed-m-q0f32-MLC-b32";
  const selectedModel = "snowflake-arctic-embed-m-q0f32-MLC-b4";
  const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
    selectedModel,
    {
      initProgressCallback: initProgressCallback,
      logLevel: "INFO", // specify the log level
    },
  );

  const vectorStore = await MemoryVectorStore.fromExistingIndex(
    new WebLLMEmbeddings(engine),
  );
  const document0: Document = {
    pageContent: documents[0],
    metadata: {},
  };
  const document1: Document = {
    pageContent: documents[1],
    metadata: {},
  };
  await vectorStore.addDocuments([document0, document1]);

  const similaritySearchResults0 = await vectorStore.similaritySearch(
    queries[0],
    1,
  );
  for (const doc of similaritySearchResults0) {
    console.log(`* ${doc.pageContent}`);
  }

  const similaritySearchResults1 = await vectorStore.similaritySearch(
    queries[1],
    1,
  );
  for (const doc of similaritySearchResults1) {
    console.log(`* ${doc.pageContent}`);
  }
}

// Select one to run
webllmAPI();
// langchainAPI();
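
The batch-size comments above imply that callers never need to chunk inputs themselves; a small sketch of that behavior, reusing an `engine` loaded with the b4 model as in `webllmAPI()` (the input strings are hypothetical):

```ts
// With a -b4 build, six inputs take two internal forward passes (4 + 2),
// but the reply still contains one embedding per input, in order.
const six = ["a", "b", "c", "d", "e", "f"].map((t) => `[CLS] ${t} [SEP]`);
const reply = await engine.embeddings.create({ input: six });
console.log(reply.data.length); // 6
```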
55 changes: 52 additions & 3 deletions src/config.ts
@@ -68,10 +68,10 @@ export interface TokenizerInfo {
* Only these fields affect the conversation in runtime.
* i.e. The third part in https://llm.mlc.ai/docs/get_started/mlc_chat_config.html.
*
- * This is initialized in `ChatModule.reload()` with the model's `mlc-chat-config.json`.
+ * This is initialized in `MLCEngine.reload()` with the model's `mlc-chat-config.json`.
*/
export interface ChatConfig {
-  // First three fields affect the entire conversation, i.e. used in `ChatModule.reload()`
+  // First three fields affect the entire conversation, i.e. used in `MLCEngine.reload()`
  tokenizer_files: Array<string>;
  tokenizer_info?: TokenizerInfo;
  token_table_postproc_method?: string; // TODO: backward compatibility, remove soon
@@ -122,7 +122,7 @@ export interface MLCEngineConfig {
* We also support additional fields not present in `mlc-chat-config.json` due to OpenAI-like APIs.
*
* Note that all values are optional. If unspecified, we use whatever values in `ChatConfig`
- * initialized during `ChatModule.reload()`.
+ * initialized during `MLCEngine.reload()`.
*/
export interface GenerationConfig {
  // Only used in MLC
@@ -226,6 +226,11 @@ export function postInitAndCheckGenerationConfigValues(
  }
}

export enum ModelType {
  "LLM",
  "embedding",
}

/**
* Information for a model.
* @param model: the huggingface link to download the model weights, accepting four formats:
@@ -241,6 +246,7 @@ export function postInitAndCheckGenerationConfigValues(
* @param low_resource_required: whether the model can run on limited devices (e.g. Android phone).
* @param buffer_size_required_bytes: required `maxStorageBufferBindingSize`, different for each device.
* @param required_features: feature needed to run this model (e.g. shader-f16).
 * @param model_type: the intended use case of the model; if unspecified, defaults to LLM.
*/
export interface ModelRecord {
  model: string;
@@ -251,6 +257,7 @@ export interface ModelRecord {
  low_resource_required?: boolean;
  buffer_size_required_bytes?: number;
  required_features?: Array<string>;
  model_type?: ModelType;
}

/**
@@ -1514,5 +1521,47 @@ export const prebuiltAppConfig: AppConfig = {
        context_window_size: 1024,
      },
    },
    // Embedding models
    // -b means the max batch size this model allows. The smaller it is, the less memory the model consumes.
    {
      model: "https://huggingface.co/mlc-ai/snowflake-arctic-embed-m-q0f32-MLC",
      model_id: "snowflake-arctic-embed-m-q0f32-MLC-b32",
      model_lib:
        modelLibURLPrefix +
        modelVersion +
        "/snowflake-arctic-embed-m-q0f32-ctx512_cs512_batch32-webgpu.wasm",
      vram_required_MB: 1407.51,
      model_type: ModelType.embedding,
    },
    {
      model: "https://huggingface.co/mlc-ai/snowflake-arctic-embed-m-q0f32-MLC",
      model_id: "snowflake-arctic-embed-m-q0f32-MLC-b4",
      model_lib:
        modelLibURLPrefix +
        modelVersion +
        "/snowflake-arctic-embed-m-q0f32-ctx512_cs512_batch4-webgpu.wasm",
      vram_required_MB: 539.4,
      model_type: ModelType.embedding,
    },
    {
      model: "https://huggingface.co/mlc-ai/snowflake-arctic-embed-s-q0f32-MLC",
      model_id: "snowflake-arctic-embed-s-q0f32-MLC-b32",
      model_lib:
        modelLibURLPrefix +
        modelVersion +
        "/snowflake-arctic-embed-s-q0f32-ctx512_cs512_batch32-webgpu.wasm",
      vram_required_MB: 1022.82,
      model_type: ModelType.embedding,
    },
    {
      model: "https://huggingface.co/mlc-ai/snowflake-arctic-embed-s-q0f32-MLC",
      model_id: "snowflake-arctic-embed-s-q0f32-MLC-b4",
      model_lib:
        modelLibURLPrefix +
        modelVersion +
        "/snowflake-arctic-embed-s-q0f32-ctx512_cs512_batch4-webgpu.wasm",
      vram_required_MB: 238.71,
      model_type: ModelType.embedding,
    },
  ],
};
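
The new `model_type` field works the same way in user-defined model records; a hedged sketch of a custom `AppConfig` entry (assuming `ModelType` is re-exported from the package root; the `model_lib` URL is left elided, and the `model_id` is hypothetical):

```ts
import * as webllm from "@mlc-ai/web-llm";

// Hypothetical custom config with one embedding record; when model_type
// is unspecified, it defaults to LLM.
const appConfig: webllm.AppConfig = {
  model_list: [
    {
      model: "https://huggingface.co/mlc-ai/snowflake-arctic-embed-s-q0f32-MLC",
      model_id: "my-arctic-embed-s-b4",
      model_lib: "...", // elided; composed from modelLibURLPrefix + modelVersion in the source
      vram_required_MB: 238.71,
      model_type: webllm.ModelType.embedding,
    },
  ],
};
```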